`__ .
+ Args:
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
+ The tokenizer used for encoding the data
+ mask_ratio (:obj:`float`):
+ The probability with which to (randomly) mask tokens in the input
+ poisson_lambda (:obj:`float`):
+ Mean parameter of Poisson distribution used to generate span-lengths to be masked
+ permute_sentence_ratio (:obj:`float`):
+ Ratio of sentences to be permuted in each document
+ decoder_start_token_id (:obj:`int`):
+ The decoder start token id of the model
+ """
+
+ tokenizer: PreTrainedTokenizerBase
+ decoder_start_token_id: int
+ mask_ratio: float = 0.3
+ poisson_lambda: float = 3.0
+ permute_sentence_ratio: float = 1.0
+
+ def __post_init__(self):
+ if self.tokenizer.mask_token is None or self.tokenizer.eos_token is None:
+ raise ValueError(
+ "This tokenizer does not have a mask token or eos token token which is necessary for denoising"
+ " language modeling. "
+ )
+
+ def __call__(self, examples: List[Dict[str, List[int]]]) -> BatchEncoding:
+ # convert list to dict and tensorize input
+ batch = BatchEncoding(
+ {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
+ )
+ batch["labels"] = batch["input_ids"].copy()
+ batch["decoder_input_ids"] = shift_tokens_right(
+ batch["labels"], self.tokenizer.pad_token_id, self.decoder_start_token_id
+ )
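+ # Rough illustration with hypothetical ids: for labels [[0, 8774, 48, 2]] and
+ # decoder_start_token_id = 2, `shift_tokens_right` yields decoder_input_ids
+ # [[2, 0, 8774, 48]], i.e. the labels shifted one position to the right with the
+ # decoder start token prepended. The real values depend on the tokenizer and config.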
+ # permuting sentences
+ do_permute = False
+ if self.permute_sentence_ratio > 0.0:
+ batch["input_ids"] = self.permute_sentences(batch["input_ids"])
+ do_permute = True
+
+ # masking span of tokens (text infilling in the paper)
+ if self.mask_ratio:
+ batch["input_ids"], batch["labels"] = self.span_mask_tokens(
+ batch["input_ids"], batch["labels"], do_permute
+ )
+
+ # ignore pad tokens
+ batch["attention_mask"] = (batch["input_ids"] != self.tokenizer.pad_token_id).astype(int)
+ batch["decoder_attention_mask"] = (batch["decoder_input_ids"] != self.tokenizer.pad_token_id).astype(int)
+ return batch
+
+ def permute_sentences(self, input_ids):
+ """
+ Shuffle sentences in each document.
+ """
+ results = input_ids.copy()
+
+ # find end locations of sentences
+ end_sentence_mask = input_ids == self.tokenizer.pad_token_id
+ sentence_ends = np.argwhere(end_sentence_mask)
+ sentence_ends[:, 1] += 1
+ example_has_multiple_sentences, num_sentences = np.unique(sentence_ends[:, 0], return_counts=True)
+ num_sentences_map = {sent_idx: count for sent_idx, count in zip(example_has_multiple_sentences, num_sentences)}
+
+ num_to_permute = np.ceil(num_sentences * self.permute_sentence_ratio).astype(int)
+ num_to_permute_map = {
+ sent_idx: count for sent_idx, count in zip(example_has_multiple_sentences, num_to_permute)
+ }
+
+ sentence_ends = np.split(sentence_ends[:, 1], np.unique(sentence_ends[:, 0], return_index=True)[1][1:])
+ sentence_ends_map = {sent_idx: count for sent_idx, count in zip(example_has_multiple_sentences, sentence_ends)}
+
+ for i in range(input_ids.shape[0]):
+ if i not in example_has_multiple_sentences:
+ continue
+ substitutions = np.random.permutation(num_sentences_map[i])[: num_to_permute_map[i]]
+ ordering = np.arange(0, num_sentences_map[i])
+ ordering[substitutions] = substitutions[np.random.permutation(num_to_permute_map[i])]
+
+ # write shuffled sentences into results
+ index = 0
+ for j in ordering:
+ sentence = input_ids[i, (sentence_ends_map[i][j - 1] if j > 0 else 0) : sentence_ends_map[i][j]]
+ results[i, index : index + sentence.shape[0]] = sentence
+ index += sentence.shape[0]
+ return results
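+
+ # Rough illustration with placeholder tokens: the pad token inserted by
+ # `sentence_split_function` below acts as the sentence separator, so an encoded document
+ #   [<s> s1 s1 <pad> s2 s2 <pad> s3 s3 </s>]
+ # is split into slices ending at each <pad>; with permute_sentence_ratio=1.0 those
+ # slices are written back in shuffled order, e.g.
+ #   [s2 s2 <pad> <s> s1 s1 <pad> s3 s3 </s>].
+ # Tokens after the last <pad> separator keep their original positions.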
+
+ def span_mask_tokens(self, input_ids, labels, do_permute):
+ """
+ Sampling text spans with span lengths drawn from a Poisson distribution and masking them.
+ """
+ special_tokens_mask_labels = [
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
+ ]
+ special_tokens_mask_inputs = [
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in input_ids.tolist()
+ ]
+ special_tokens_mask_labels = np.array(special_tokens_mask_labels, dtype=bool)
+ special_tokens_mask_inputs = np.array(special_tokens_mask_inputs, dtype=bool)
+
+ # determine how many tokens we need to mask in total
+ is_token_mask = ~(input_ids == self.tokenizer.pad_token_id) & ~special_tokens_mask_inputs
+ num_tokens_to_mask = int(math.ceil(is_token_mask.astype(float).sum() * self.mask_ratio))
+ if num_tokens_to_mask == 0:
+ return input_ids, labels
+
+ # generate a sufficient number of span lengths
+ span_lengths = np.random.poisson(lam=self.poisson_lambda, size=(num_tokens_to_mask,))
+ while np.cumsum(span_lengths, 0)[-1] < num_tokens_to_mask:
+ span_lengths = np.concatenate(
+ [span_lengths, np.random.poisson(lam=self.poisson_lambda, size=(num_tokens_to_mask,))]
+ )
+
+ # remove all spans of length 0
+ # note that BART inserts additional mask tokens where length == 0,
+ # which we do not implement for now as it adds additional complexity
+ span_lengths = span_lengths[span_lengths > 0]
+
+ # trim to about num_tokens_to_mask tokens
+ cutoff_idx = np.argmin(np.abs(np.cumsum(span_lengths, 0) - num_tokens_to_mask)) + 1
+ span_lengths = span_lengths[:cutoff_idx]
+
+ # randomly choose starting positions for masking
+ token_indices = np.argwhere(is_token_mask == 1)
+ span_starts = np.random.permutation(token_indices.shape[0])[: span_lengths.shape[0]]
+ # prepare mask
+ masked_indices = np.array(token_indices[span_starts])
+ mask = np.full_like(input_ids, fill_value=False)
+
+ # mask starting positions
+ for mi in masked_indices:
+ mask[tuple(mi)] = True
+ span_lengths -= 1
+
+ # fill up spans
+ max_index = input_ids.shape[1] - 1
+ remaining = (span_lengths > 0) & (masked_indices[:, 1] < max_index)
+ while np.any(remaining):
+ masked_indices[remaining, 1] += 1
+ for mi in masked_indices:
+ mask[tuple(mi)] = True
+ span_lengths -= 1
+ remaining = (span_lengths > 0) & (masked_indices[:, 1] < max_index)
+
+ # place the mask tokens
+ mask[np.where(special_tokens_mask_inputs)] = False
+ input_ids[np.where(mask)] = self.tokenizer.mask_token_id
+ if not do_permute:
+ labels[np.where(mask == 0)] = -100
+ else:
+ labels[np.where(special_tokens_mask_labels)] = -100
+
+ # remove mask tokens that are not starts of spans
+ to_remove = (mask == 1) & np.roll((mask == 1), 1, 1)
+ new_input_ids = np.full_like(input_ids, fill_value=self.tokenizer.pad_token_id)
+ for i, example in enumerate(input_ids):
+ new_example = example[~to_remove[i]]
+ new_input_ids[i, : new_example.shape[0]] = new_example
+
+ return new_input_ids, labels
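+
+ # Rough illustration of text infilling with hypothetical ids: if the tokens
+ #   [<s> a b c d e </s>]
+ # receive a sampled span of length 3 starting at b, then b, c and d are replaced by
+ # <mask>, and the mask tokens that are not span starts are dropped, giving
+ #   [<s> a <mask> e </s> <pad> <pad>]
+ # (right-padded back to the original length). With do_permute=False the labels keep the
+ # original tokens only at masked positions; everything else is set to -100 and ignored
+ # by the loss.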
+
+
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
+ num_samples = len(samples_idx)
+ if drop_last:
+ samples_to_remove = num_samples % batch_size
+ if samples_to_remove != 0:
+ samples_idx = samples_idx[:-samples_to_remove]
+ sections_split = num_samples // batch_size
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
+ else:
+ sections_split = math.ceil(num_samples / batch_size)
+ samples_idx = np.array_split(samples_idx, sections_split)
+ return samples_idx
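+
+ # For illustration (arbitrary values):
+ #   generate_batch_splits(np.arange(10), 4)                  -> array([[0, 1, 2, 3],
+ #                                                                      [4, 5, 6, 7]])
+ #   generate_batch_splits(np.arange(10), 4, drop_last=False) -> [array([0, 1, 2, 3]),
+ #                                                                array([4, 5, 6]),
+ #                                                                array([7, 8, 9])]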
+
+
+def write_train_metric(summary_writer, train_metrics, train_time, step):
+ summary_writer.scalar("train_time", train_time, step)
+
+ train_metrics = get_metrics(train_metrics)
+ for key, vals in train_metrics.items():
+ tag = f"train_{key}"
+ for i, val in enumerate(vals):
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+
+
+def write_eval_metric(summary_writer, eval_metrics, step):
+ for metric_name, value in eval_metrics.items():
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
+
+
+def main():
+ # See all possible arguments in src/transformers/training_args.py
+ # or by passing the --help flag to this script.
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+ # If we pass only one argument to the script and it's the path to a json file,
+ # let's parse it to get our arguments.
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+ else:
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_bart_dlm", model_args, data_args, framework="flax")
+
+ if (
+ os.path.exists(training_args.output_dir)
+ and os.listdir(training_args.output_dir)
+ and training_args.do_train
+ and not training_args.overwrite_output_dir
+ ):
+ raise ValueError(
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
+ "Use --overwrite_output_dir to overcome."
+ )
+
+ # Setup logging
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ level=logging.INFO,
+ datefmt="[%X]",
+ )
+
+ # Log on each process the small summary:
+ logger = logging.getLogger(__name__)
+
+ # Set the verbosity to info of the Transformers logger (on main process only):
+ logger.info(f"Training/evaluation parameters {training_args}")
+
+ # Set seed before initializing model.
+ set_seed(training_args.seed)
+
+ # Handle the repository creation
+ if training_args.push_to_hub:
+ if training_args.hub_model_id is None:
+ repo_name = get_full_repo_name(
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+ )
+ else:
+ repo_name = training_args.hub_model_id
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
+
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+ # (the dataset will be downloaded automatically from the datasets Hub).
+ #
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
+ # 'text' is found. You can easily tweak this behavior (see below).
+ if data_args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ datasets = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+
+ if "validation" not in datasets.keys():
+ datasets["validation"] = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ split=f"train[:{data_args.validation_split_percentage}%]",
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ datasets["train"] = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ split=f"train[{data_args.validation_split_percentage}%:]",
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ else:
+ data_files = {}
+ if data_args.train_file is not None:
+ data_files["train"] = data_args.train_file
+ if data_args.validation_file is not None:
+ data_files["validation"] = data_args.validation_file
+ extension = data_args.train_file.split(".")[-1]
+ if extension == "txt":
+ extension = "text"
+ datasets = load_dataset(
+ extension,
+ data_files=data_files,
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+
+ if "validation" not in datasets.keys():
+ datasets["validation"] = load_dataset(
+ extension,
+ data_files=data_files,
+ split=f"train[:{data_args.validation_split_percentage}%]",
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ datasets["train"] = load_dataset(
+ extension,
+ data_files=data_files,
+ split=f"train[{data_args.validation_split_percentage}%:]",
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+ # Load pretrained model and tokenizer
+
+ if model_args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.tokenizer_name,
+ cache_dir=model_args.cache_dir,
+ use_fast=model_args.use_fast_tokenizer,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ elif model_args.model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=model_args.cache_dir,
+ use_fast=model_args.use_fast_tokenizer,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ else:
+ raise ValueError(
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+ )
+
+ if model_args.config_name:
+ config = BartConfig.from_pretrained(
+ model_args.config_name,
+ cache_dir=model_args.cache_dir,
+ vocab_size=len(tokenizer),
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ elif model_args.model_name_or_path:
+ config = BartConfig.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ else:
+ config = CONFIG_MAPPING[model_args.model_type]()
+ logger.warning("You are instantiating a new config instance from scratch.")
+
+ # Preprocessing the datasets.
+ # First we tokenize all the texts.
+ if training_args.do_train:
+ column_names = datasets["train"].column_names
+ else:
+ column_names = datasets["validation"].column_names
+ text_column_name = "text" if "text" in column_names else column_names[0]
+
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
+ # Use Punkt Sentence Tokenizer to divide a document into a list of sentences
+ nltk.download("punkt")
+ sentence_tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")
+
+ def sentence_split_function(example):
+ sents = sentence_tokenizer.tokenize(example["text"])
+ # use pad token as end of sentence indicator
+ new_text = tokenizer.bos_token + f"{tokenizer.pad_token}".join(sents) + tokenizer.eos_token
+ return {"text": new_text}
+
+ split_datasets = datasets.map(
+ sentence_split_function,
+ batched=False,
+ num_proc=data_args.preprocessing_num_workers,
+ remove_columns=column_names,
+ load_from_cache_file=not data_args.overwrite_cache,
+ )
+
+ # Tokenize every text, then concatenate them together before splitting them into smaller parts.
+ # Since we make sure that all sequences are of the same length, no attention_mask is needed.
+ def tokenize_function(examples):
+ return tokenizer(examples[text_column_name], add_special_tokens=False, return_attention_mask=False)
+
+ tokenized_datasets = split_datasets.map(
+ tokenize_function,
+ batched=True,
+ num_proc=data_args.preprocessing_num_workers,
+ remove_columns=text_column_name,
+ load_from_cache_file=not data_args.overwrite_cache,
+ )
+
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
+ # max_seq_length.
+ def group_texts(examples):
+ # Concatenate all texts.
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
+ # customize this part to your needs.
+ if total_length >= max_seq_length:
+ total_length = (total_length // max_seq_length) * max_seq_length
+ # Split by chunks of max_len.
+ result = {
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
+ for k, t in concatenated_examples.items()
+ }
+ return result
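+
+ # For example, with max_seq_length=512 and 1,300 concatenated tokens in a batch of
+ # examples, the trailing 276 tokens are dropped and two chunks of 512 tokens are kept
+ # (numbers are illustrative only).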
+
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
+ # might be slower to preprocess.
+ #
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+ tokenized_datasets = tokenized_datasets.map(
+ group_texts,
+ batched=True,
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=not data_args.overwrite_cache,
+ )
+
+ # Enable tensorboard only on the master node
+ has_tensorboard = is_tensorboard_available()
+ if has_tensorboard and jax.process_index() == 0:
+ try:
+ from flax.metrics.tensorboard import SummaryWriter
+
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+ except ImportError as ie:
+ has_tensorboard = False
+ logger.warning(
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+ )
+ else:
+ logger.warning(
+ "Unable to display metrics through TensorBoard because the package is not installed: "
+ "Please run pip install tensorboard to enable."
+ )
+
+ # Initialize our training
+ rng = jax.random.PRNGKey(training_args.seed)
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
+
+ if model_args.model_name_or_path:
+ model = FlaxBartForConditionalGeneration.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ seed=training_args.seed,
+ dtype=getattr(jnp, model_args.dtype),
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ else:
+ config.vocab_size = len(tokenizer)
+ model = FlaxBartForConditionalGeneration(
+ config,
+ seed=training_args.seed,
+ dtype=getattr(jnp, model_args.dtype),
+ )
+
+ # Data collator
+ # This one will take care of randomly masking the tokens and permuting the sentences.
+ data_collator = FlaxDataCollatorForBartDenoisingLM(
+ tokenizer=tokenizer,
+ decoder_start_token_id=model.config.decoder_start_token_id,
+ mask_ratio=data_args.mlm_probability,
+ poisson_lambda=data_args.poisson_lambda,
+ permute_sentence_ratio=data_args.permute_sentence_ratio,
+ )
+
+ # Store some constants
+ num_epochs = int(training_args.num_train_epochs)
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+ eval_batch_size = per_device_eval_batch_size * jax.device_count()
+
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
+
+ # Create learning rate schedule
+ warmup_fn = optax.linear_schedule(
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
+ )
+ decay_fn = optax.linear_schedule(
+ init_value=training_args.learning_rate,
+ end_value=0,
+ transition_steps=num_train_steps - training_args.warmup_steps,
+ )
+ linear_decay_lr_schedule_fn = optax.join_schedules(
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
+ )
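+
+ # Shape of the resulting schedule, with illustrative values (learning_rate=1e-3,
+ # warmup_steps=100, num_train_steps=1000): step 0 -> 0.0, step 50 -> 5e-4,
+ # step 100 -> 1e-3, step 550 -> 5e-4, step 1000 -> 0.0, i.e. linear warmup
+ # followed by linear decay to zero.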
+
+ # We use Optax's "masking" functionality to not apply weight decay
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+ # mask boolean with the same structure as the parameters.
+ # The mask is True for parameters that should be decayed.
+ def decay_mask_fn(params):
+ flat_params = traverse_util.flatten_dict(params)
+ # find out all LayerNorm parameters
+ layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+ layer_norm_named_params = set(
+ [
+ layer[-2:]
+ for layer_norm_name in layer_norm_candidates
+ for layer in flat_params.keys()
+ if layer_norm_name in "".join(layer).lower()
+ ]
+ )
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
+ return traverse_util.unflatten_dict(flat_mask)
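+
+ # As a rough example for a BART-like parameter tree (names are illustrative only):
+ #   ("encoder", "layers", "0", "fc1", "kernel")                 -> decayed
+ #   ("encoder", "layers", "0", "fc1", "bias")                   -> not decayed
+ #   ("encoder", "layers", "0", "self_attn_layer_norm", "scale") -> not decayed
+ # because any path whose joined name contains "layernorm", "layer_norm" or "ln"
+ # contributes its last two components to the exclusion set.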
+
+ # create adam optimizer
+ if training_args.adafactor:
+ # We use the default parameters here to initialize adafactor.
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+ optimizer = optax.adafactor(
+ learning_rate=linear_decay_lr_schedule_fn,
+ )
+ else:
+ optimizer = optax.adamw(
+ learning_rate=linear_decay_lr_schedule_fn,
+ b1=training_args.adam_beta1,
+ b2=training_args.adam_beta2,
+ weight_decay=training_args.weight_decay,
+ mask=decay_mask_fn,
+ )
+
+ # Setup train state
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
+
+ # Define gradient update step fn
+ def train_step(state, batch, dropout_rng):
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
+
+ def loss_fn(params):
+ labels = batch.pop("labels")
+
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+
+ # compute loss, ignore padded input tokens and special tokens
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
+
+ # take average
+ loss = loss.sum() / label_mask.sum()
+
+ return loss
+
+ grad_fn = jax.value_and_grad(loss_fn)
+ loss, grad = grad_fn(state.params)
+ grad = jax.lax.pmean(grad, "batch")
+ new_state = state.apply_gradients(grads=grad)
+
+ metrics = jax.lax.pmean(
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
+ )
+
+ return new_state, metrics, new_dropout_rng
+
+ # Create parallel version of the train step
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+
+ # Define eval fn
+ def eval_step(params, batch):
+ labels = batch.pop("labels")
+
+ logits = model(**batch, params=params, train=False)[0]
+
+ # compute loss, ignore padded input tokens and special tokens
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
+
+ # compute accuracy
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
+
+ # summarize metrics
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
+ metrics = jax.lax.psum(metrics, axis_name="batch")
+
+ return metrics
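+
+ # The summed loss/accuracy are divided by `normalizer` (the number of label positions
+ # that enter the loss) once all eval batches have been accumulated, see the evaluation
+ # loops below.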
+
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
+
+ # Replicate the train state on each device
+ state = jax_utils.replicate(state)
+
+ train_time = 0
+ epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
+ for epoch in epochs:
+ # ======================== Training ================================
+ train_start = time.time()
+ train_metrics = []
+
+ # Create sampling rng
+ rng, input_rng = jax.random.split(rng)
+
+ # Generate an epoch by shuffling sampling indices from the train dataset
+ num_train_samples = len(tokenized_datasets["train"])
+ # Avoid using jax.numpy here in case of TPU training
+ train_samples_idx = np.random.permutation(np.arange(num_train_samples))
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
+
+ # Gather the indexes for creating the batch and do a training step
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
+ model_inputs = data_collator(samples)
+
+ # Model forward
+ model_inputs = shard(model_inputs.data)
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
+ train_metrics.append(train_metric)
+
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
+
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+ # Save metrics
+ train_metric = jax_utils.unreplicate(train_metric)
+ train_time += time.time() - train_start
+ if has_tensorboard and jax.process_index() == 0:
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+ epochs.write(
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
+ )
+
+ train_metrics = []
+
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
+ # ======================== Evaluating ==============================
+ num_eval_samples = len(tokenized_datasets["validation"])
+ # Avoid using jax.numpy here in case of TPU training
+ eval_samples_idx = np.arange(num_eval_samples)
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
+
+ eval_metrics = []
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
+ model_inputs = data_collator(samples)
+
+ # Model forward
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+ state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
+ )
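+ # `pad_shard_unpad` pads the last, possibly incomplete batch so it can be sharded
+ # evenly across devices, runs the pmapped eval step, and with `static_return=True`
+ # passes the already psum-reduced metrics through unchanged (a brief summary of the
+ # flax utility; see flax.jax_utils for the authoritative documentation).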
+ eval_metrics.append(metrics)
+
+ # normalize eval metrics
+ eval_metrics = get_metrics(eval_metrics)
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+ eval_normalizer = eval_metrics.pop("normalizer")
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+
+ # Update progress bar
+ epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
+
+ # Save metrics
+ if has_tensorboard and jax.process_index() == 0:
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
+
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
+ # save a checkpoint and push it to the hub every save_steps
+ if jax.process_index() == 0:
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+ model.save_pretrained(training_args.output_dir, params=params)
+ tokenizer.save_pretrained(training_args.output_dir)
+ if training_args.push_to_hub:
+ repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
+
+ # Eval after training
+ if training_args.do_eval:
+ num_eval_samples = len(tokenized_datasets["validation"])
+ # Avoid using jax.numpy here in case of TPU training
+ eval_samples_idx = np.arange(num_eval_samples)
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
+
+ eval_metrics = []
+ for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
+ model_inputs = data_collator(samples)
+
+ # Model forward
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+ state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
+ )
+ eval_metrics.append(metrics)
+
+ # normalize eval metrics
+ eval_metrics = get_metrics(eval_metrics)
+ eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
+ eval_normalizer = eval_metrics.pop("normalizer")
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+
+ try:
+ perplexity = math.exp(eval_metrics["loss"])
+ except OverflowError:
+ perplexity = float("inf")
+ eval_metrics["perplexity"] = perplexity
+
+ if jax.process_index() == 0:
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
+ path = os.path.join(training_args.output_dir, "eval_results.json")
+ with open(path, "w") as f:
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py
index afb6d75b3857..1a0428fdd670 100755
--- a/examples/flax/language-modeling/run_clm_flax.py
+++ b/examples/flax/language-modeling/run_clm_flax.py
@@ -43,7 +43,7 @@
import optax
import transformers
from flax import jax_utils, traverse_util
-from flax.jax_utils import unreplicate
+from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
@@ -58,7 +58,7 @@
set_seed,
)
from transformers.testing_utils import CaptureLogger
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
logger = logging.getLogger(__name__)
@@ -138,8 +138,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -162,14 +163,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -194,15 +200,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
overwrite_cache: bool = field(
@@ -217,9 +227,11 @@ class DataTrainingArguments:
block_size: Optional[int] = field(
default=None,
metadata={
- "help": "Optional input sequence length after tokenization. "
- "The training dataset will be truncated in block of this size for training. "
- "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ "help": (
+ "Optional input sequence length after tokenization. "
+ "The training dataset will be truncated in block of this size for training. "
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ )
},
)
overwrite_cache: bool = field(
@@ -252,20 +264,24 @@ def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
-def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
+def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
"""
- Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
- Shuffle batches if `shuffle` is `True`.
+ Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
+ and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
"""
- steps_per_epoch = len(dataset) // batch_size
-
if shuffle:
batch_idx = jax.random.permutation(rng, len(dataset))
+ batch_idx = np.asarray(batch_idx)
else:
- batch_idx = jnp.arange(len(dataset))
+ batch_idx = np.arange(len(dataset))
- batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
- batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+ if drop_last:
+ steps_per_epoch = len(dataset) // batch_size
+ batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+ else:
+ steps_per_epoch = math.ceil(len(dataset) / batch_size)
+ batch_idx = np.array_split(batch_idx, steps_per_epoch)
for idx in batch_idx:
batch = dataset[idx]
@@ -316,6 +332,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_clm", model_args, data_args, framework="flax")
+
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
@@ -485,7 +505,6 @@ def main():
config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
- use_auth_token=True if model_args.use_auth_token else None,
)
# Preprocessing the datasets.
@@ -505,7 +524,8 @@ def tokenize_function(examples):
# clm input could be much much longer than block_size
if "Token indices sequence length is longer than the" in cl.out:
tok_logger.warning(
- "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
+ " before being passed to the model."
)
return output
@@ -605,7 +625,8 @@ def group_texts(examples):
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+ eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
@@ -622,15 +643,19 @@ def group_texts(examples):
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
- # Note that this mask is specifically adapted for FlaxGPT2.
- # For other models, one should correct the layer norm parameter naming
- # accordingly.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
- flat_mask = {
- path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
- for path in flat_params
- }
+ # find out all LayerNorm parameters
+ layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+ layer_norm_named_params = set(
+ [
+ layer[-2:]
+ for layer_norm_name in layer_norm_candidates
+ for layer in flat_params.keys()
+ if layer_norm_name in "".join(layer).lower()
+ ]
+ )
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
@@ -735,7 +760,8 @@ def eval_step(params, batch):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
+ f" {train_metric['learning_rate'].mean()})"
)
train_metrics = []
@@ -743,13 +769,14 @@ def eval_step(params, batch):
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
# ======================== Evaluating ==============================
eval_metrics = []
- eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
- eval_steps = len(eval_dataset) // eval_batch_size
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
+ eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
# Model forward
batch = next(eval_loader)
- batch = shard(batch)
- metrics = p_eval_step(state.params, batch)
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+ state.params, batch, min_device_batch=per_device_eval_batch_size
+ )
eval_metrics.append(metrics)
# normalize eval metrics
@@ -762,7 +789,10 @@ def eval_step(params, batch):
eval_metrics["perplexity"] = float("inf")
# Print metrics and update progress bar
- desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
+ desc = (
+ f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity:"
+ f" {eval_metrics['perplexity']})"
+ )
epochs.write(desc)
epochs.desc = desc
@@ -782,12 +812,14 @@ def eval_step(params, batch):
# Eval after training
if training_args.do_eval:
eval_metrics = []
- eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
- eval_steps = len(eval_dataset) // eval_batch_size
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
+ eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
# Model forward
- batch = shard(next(eval_loader))
- metrics = p_eval_step(state.params, batch)
+ batch = next(eval_loader)
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+ state.params, batch, min_device_batch=per_device_eval_batch_size
+ )
eval_metrics.append(metrics)
# normalize eval metrics
diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py
index 6ea0f6e1564f..65f6a2285d9c 100755
--- a/examples/flax/language-modeling/run_mlm_flax.py
+++ b/examples/flax/language-modeling/run_mlm_flax.py
@@ -43,6 +43,7 @@
import jax.numpy as jnp
import optax
from flax import jax_utils, traverse_util
+from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
@@ -58,7 +59,7 @@
is_tensorboard_available,
set_seed,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
@@ -136,8 +137,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -160,14 +162,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -209,8 +216,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated. Default to the max input length of the model."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated. Default to the max input length of the model."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -223,8 +232,10 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
line_by_line: bool = field(
@@ -316,15 +327,20 @@ def mask_tokens(
return inputs, labels
-def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
num_samples = len(samples_idx)
- samples_to_remove = num_samples % batch_size
-
- if samples_to_remove != 0:
- samples_idx = samples_idx[:-samples_to_remove]
- sections_split = num_samples // batch_size
- batch_idx = np.split(samples_idx, sections_split)
- return batch_idx
+ if drop_last:
+ samples_to_remove = num_samples % batch_size
+ if samples_to_remove != 0:
+ samples_idx = samples_idx[:-samples_to_remove]
+ sections_split = num_samples // batch_size
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
+ else:
+ sections_split = math.ceil(num_samples / batch_size)
+ samples_idx = np.array_split(samples_idx, sections_split)
+ return samples_idx
def write_train_metric(summary_writer, train_metrics, train_time, step):
@@ -355,6 +371,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_mlm", model_args, data_args, framework="flax")
+
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
@@ -618,13 +638,13 @@ def group_texts(examples):
config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
- use_auth_token=True if model_args.use_auth_token else None,
)
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+ eval_batch_size = per_device_eval_batch_size * jax.device_count()
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
@@ -645,12 +665,19 @@ def group_texts(examples):
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
- # Note that this mask is specifically adapted for FlaxBERT-like models.
- # For other models, one should correct the layer norm parameter naming
- # accordingly.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
- flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+ # find out all LayerNorm parameters
+ layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+ layer_norm_named_params = set(
+ [
+ layer[-2:]
+ for layer_norm_name in layer_norm_candidates
+ for layer in flat_params.keys()
+ if layer_norm_name in "".join(layer).lower()
+ ]
+ )
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
@@ -741,7 +768,8 @@ def eval_step(params, batch):
# Generate an epoch by shuffling sampling indices from the train dataset
num_train_samples = len(tokenized_datasets["train"])
- train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
+ # Avoid using jax.numpy here in case of TPU training
+ train_samples_idx = np.random.permutation(np.arange(num_train_samples))
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
# Gather the indexes for creating the batch and do a training step
@@ -764,7 +792,8 @@ def eval_step(params, batch):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
train_metrics = []
@@ -772,8 +801,9 @@ def eval_step(params, batch):
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
# ======================== Evaluating ==============================
num_eval_samples = len(tokenized_datasets["validation"])
- eval_samples_idx = jnp.arange(num_eval_samples)
- eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+ # Avoid using jax.numpy here in case of TPU training
+ eval_samples_idx = np.arange(num_eval_samples)
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
@@ -781,8 +811,9 @@ def eval_step(params, batch):
model_inputs = data_collator(samples, pad_to_multiple_of=16)
# Model forward
- model_inputs = shard(model_inputs.data)
- metrics = p_eval_step(state.params, model_inputs)
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+ state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
+ )
eval_metrics.append(metrics)
# normalize eval metrics
@@ -810,8 +841,9 @@ def eval_step(params, batch):
# Eval after training
if training_args.do_eval:
num_eval_samples = len(tokenized_datasets["validation"])
- eval_samples_idx = jnp.arange(num_eval_samples)
- eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+ # Avoid using jax.numpy here in case of TPU training
+ eval_samples_idx = np.arange(num_eval_samples)
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
eval_metrics = []
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
@@ -819,8 +851,9 @@ def eval_step(params, batch):
model_inputs = data_collator(samples, pad_to_multiple_of=16)
# Model forward
- model_inputs = shard(model_inputs.data)
- metrics = p_eval_step(state.params, model_inputs)
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+ state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
+ )
eval_metrics.append(metrics)
# normalize eval metrics
diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py
index 368ecf0e61c0..0030fc8da66a 100755
--- a/examples/flax/language-modeling/run_t5_mlm_flax.py
+++ b/examples/flax/language-modeling/run_t5_mlm_flax.py
@@ -21,6 +21,7 @@
"""
import json
import logging
+import math
import os
import sys
import time
@@ -41,6 +42,7 @@
import jax.numpy as jnp
import optax
from flax import jax_utils, traverse_util
+from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
@@ -57,7 +59,7 @@
set_seed,
)
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
@@ -135,8 +137,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -159,14 +162,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -208,7 +216,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization and masking. Sequences longer than this will be truncated. Default to the max input length of the model."
+ "help": (
+ "The maximum total input sequence length after tokenization and masking. Sequences longer than this"
+ " will be truncated. Default to the max input length of the model."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -316,7 +327,7 @@ class FlaxDataCollatorForT5MLM:
pad_token_id: int
decoder_start_token_id: int
- def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
+ def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding:
# convert list to dict and tensorize input
batch = BatchEncoding(
@@ -337,12 +348,14 @@ def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarra
if batch["input_ids"].shape[-1] != self.input_length:
raise ValueError(
- f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but should be {self.target_length}."
+ f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but"
+ f" should be {self.target_length}."
)
if batch["labels"].shape[-1] != self.target_length:
raise ValueError(
- f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be {self.target_length}."
+ f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be"
+ f" {self.target_length}."
)
# to check that tokens are correctly preprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
@@ -448,15 +461,20 @@ def _random_segmentation(num_items, num_segments):
return is_noise[:orig_length]
-def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
num_samples = len(samples_idx)
- samples_to_remove = num_samples % batch_size
-
- if samples_to_remove != 0:
- samples_idx = samples_idx[:-samples_to_remove]
- sections_split = num_samples // batch_size
- batch_idx = np.split(samples_idx, sections_split)
- return batch_idx
+ if drop_last:
+ samples_to_remove = num_samples % batch_size
+ if samples_to_remove != 0:
+ samples_idx = samples_idx[:-samples_to_remove]
+ sections_split = num_samples // batch_size
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
+ else:
+ sections_split = math.ceil(num_samples / batch_size)
+ samples_idx = np.array_split(samples_idx, sections_split)
+ return samples_idx
def write_train_metric(summary_writer, train_metrics, train_time, step):
@@ -487,6 +505,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_t5_mlm", model_args, data_args, framework="flax")
+
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
@@ -724,7 +746,6 @@ def group_texts(examples):
config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
- use_auth_token=True if model_args.use_auth_token else None,
)
# Data collator
@@ -742,7 +763,8 @@ def group_texts(examples):
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+ eval_batch_size = per_device_eval_batch_size * jax.device_count()
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
@@ -768,10 +790,17 @@ def group_texts(examples):
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
- flat_mask = {
- path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
- for path in flat_params
- }
+ # find out all LayerNorm parameters
+ layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+ layer_norm_named_params = set(
+ [
+ layer[-2:]
+ for layer_norm_name in layer_norm_candidates
+ for layer in flat_params.keys()
+ if layer_norm_name in "".join(layer).lower()
+ ]
+ )
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
@@ -856,6 +885,7 @@ def eval_step(params, batch):
# Generate an epoch by shuffling sampling indices from the train dataset
num_train_samples = len(tokenized_datasets["train"])
+ # Avoid using jax.numpy here in case of TPU training
train_samples_idx = np.random.permutation(np.arange(num_train_samples))
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
@@ -884,7 +914,8 @@ def eval_step(params, batch):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
+ f" {train_metric['learning_rate'].mean()})"
)
train_metrics = []
@@ -892,8 +923,9 @@ def eval_step(params, batch):
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
# ======================== Evaluating ==============================
num_eval_samples = len(tokenized_datasets["validation"])
- eval_samples_idx = jnp.arange(num_eval_samples)
- eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+ # Avoid using jax.numpy here in case of TPU training
+ eval_samples_idx = np.arange(num_eval_samples)
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
@@ -901,8 +933,9 @@ def eval_step(params, batch):
model_inputs = data_collator(samples)
# Model forward
- model_inputs = shard(model_inputs.data)
- metrics = p_eval_step(state.params, model_inputs)
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+ state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
+ )
eval_metrics.append(metrics)
# get eval metrics
@@ -928,8 +961,9 @@ def eval_step(params, batch):
# Eval after training
if training_args.do_eval:
num_eval_samples = len(tokenized_datasets["validation"])
- eval_samples_idx = jnp.arange(num_eval_samples)
- eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+ # Avoid using jax.numpy here in case of TPU training
+ eval_samples_idx = np.arange(num_eval_samples)
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
@@ -937,8 +971,9 @@ def eval_step(params, batch):
model_inputs = data_collator(samples)
# Model forward
- model_inputs = shard(model_inputs.data)
- metrics = p_eval_step(state.params, model_inputs)
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+ state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
+ )
eval_metrics.append(metrics)
# get eval metrics
diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py
index ac4ec706bfcf..1b951e358398 100644
--- a/examples/flax/question-answering/run_qa.py
+++ b/examples/flax/question-answering/run_qa.py
@@ -20,27 +20,28 @@
import json
import logging
+import math
import os
import random
import sys
import time
from dataclasses import asdict, dataclass, field
from enum import Enum
-from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
import datasets
import numpy as np
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
from tqdm import tqdm
+import evaluate
import jax
import jax.numpy as jnp
import optax
import transformers
from flax import struct, traverse_util
-from flax.jax_utils import replicate, unreplicate
+from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
@@ -53,14 +54,14 @@
PreTrainedTokenizerFast,
is_tensorboard_available,
)
-from transformers.utils import check_min_version, get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from utils_qa import postprocess_qa_predictions
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset
@@ -157,14 +158,19 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
@@ -200,37 +206,46 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=384,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
- "be faster on GPU but will be slower on TPU)."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
+ " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
version_2_with_negative: bool = field(
@@ -239,9 +254,11 @@ class DataTrainingArguments:
null_score_diff_threshold: float = field(
default=0.0,
metadata={
- "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`."
+ "help": (
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ )
},
)
doc_stride: int = field(
@@ -255,8 +272,10 @@ class DataTrainingArguments:
max_answer_length: int = field(
default=30,
metadata={
- "help": "The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another."
+ "help": (
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ )
},
)
@@ -309,12 +328,19 @@ class TrainState(train_state.TrainState):
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
- # Note that this mask is specifically adapted for FlaxBERT-like models.
- # For other models, one should correct the layer norm parameter naming
- # accordingly.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
- flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+ # find out all LayerNorm parameters
+ layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+ layer_norm_named_params = set(
+ [
+ layer[-2:]
+ for layer_norm_name in layer_norm_candidates
+ for layer in flat_params.keys()
+ if layer_norm_name in "".join(layer).lower()
+ ]
+ )
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
tx = optax.adamw(
@@ -381,11 +407,15 @@ def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
# region eval data iterator
def eval_data_collator(dataset: Dataset, batch_size: int):
- """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
- for i in range(len(dataset) // batch_size):
- batch = dataset[i * batch_size : (i + 1) * batch_size]
+ """Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
+ batch_idx = np.arange(len(dataset))
+
+ steps_per_epoch = math.ceil(len(dataset) / batch_size)
+ batch_idx = np.array_split(batch_idx, steps_per_epoch)
+
+ for idx in batch_idx:
+ batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()}
- batch = shard(batch)
yield batch
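
The switch from fixed-size slicing to `np.array_split` is what lets the eval collator above emit the trailing partial batch instead of dropping it. A small illustration with made-up sizes:

import math
import numpy as np

# Illustrative numbers only: 10 eval examples with an eval batch size of 4.
dataset_len, batch_size = 10, 4

batch_idx = np.arange(dataset_len)
steps_per_epoch = math.ceil(dataset_len / batch_size)  # 3, instead of 10 // 4 == 2
for idx in np.array_split(batch_idx, steps_per_epoch):
    print(idx)
# [0 1 2 3]
# [4 5 6]
# [7 8 9]

Note that `np.array_split` balances the splits (4/3/3 rather than 4/4/2); every split stays at or below `batch_size`, which `pad_shard_unpad` then pads back up to a full device grid.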
@@ -406,6 +436,10 @@ def main():
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_qa", model_args, data_args, framework="flax")
# endregion
# region Logging
@@ -498,9 +532,9 @@ def main():
# region Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
- "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
- "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
- "requirement"
+ "This example script only works for models that have a fast tokenizer. Checkout the big table of models at"
+ " https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet"
+ " this requirement"
)
# endregion
@@ -743,7 +777,7 @@ def post_processing_function(examples, features, predictions, stage="eval"):
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
- metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
+ metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)
@@ -827,8 +861,9 @@ def write_eval_metric(summary_writer, eval_metrics, step):
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
- train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
- eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.local_device_count()
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+ eval_batch_size = per_device_eval_batch_size * jax.local_device_count()
# endregion
# region Load model
@@ -928,7 +963,8 @@ def eval_step(state, batch):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
train_metrics = []
@@ -945,32 +981,17 @@ def eval_step(state, batch):
# evaluate
for batch in tqdm(
eval_data_collator(eval_dataset, eval_batch_size),
- total=len(eval_dataset) // eval_batch_size,
+ total=math.ceil(len(eval_dataset) / eval_batch_size),
desc="Evaluating ...",
position=2,
):
_ = batch.pop("example_id")
_ = batch.pop("offset_mapping")
- predictions = p_eval_step(state, batch)
- start_logits = np.array([pred for pred in chain(*predictions[0])])
- end_logits = np.array([pred for pred in chain(*predictions[1])])
- all_start_logits.append(start_logits)
- all_end_logits.append(end_logits)
-
- # evaluate also on leftover examples (not divisible by batch_size)
- num_leftover_samples = len(eval_dataset) % eval_batch_size
-
- # make sure leftover batch is evaluated on one device
- if num_leftover_samples > 0 and jax.process_index() == 0:
- # take leftover samples
- batch = eval_dataset[-num_leftover_samples:]
- batch = {k: np.array(v) for k, v in batch.items()}
- _ = batch.pop("example_id")
- _ = batch.pop("offset_mapping")
-
- predictions = eval_step(unreplicate(state), batch)
- start_logits = np.array([pred for pred in predictions[0]])
- end_logits = np.array([pred for pred in predictions[1]])
+ predictions = pad_shard_unpad(p_eval_step)(
+ state, batch, min_device_batch=per_device_eval_batch_size
+ )
+ start_logits = np.array(predictions[0])
+ end_logits = np.array(predictions[1])
all_start_logits.append(start_logits)
all_end_logits.append(end_logits)
@@ -1009,30 +1030,15 @@ def eval_step(state, batch):
all_start_logits = []
all_end_logits = []
- eva_loader = eval_data_collator(eval_dataset, eval_batch_size)
- for batch in tqdm(eva_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
- _ = batch.pop("example_id")
- _ = batch.pop("offset_mapping")
- predictions = p_eval_step(state, batch)
- start_logits = np.array([pred for pred in chain(*predictions[0])])
- end_logits = np.array([pred for pred in chain(*predictions[1])])
- all_start_logits.append(start_logits)
- all_end_logits.append(end_logits)
-
- # evaluate also on leftover examples (not divisible by batch_size)
- num_leftover_samples = len(eval_dataset) % eval_batch_size
-
- # make sure leftover batch is evaluated on one device
- if num_leftover_samples > 0 and jax.process_index() == 0:
- # take leftover samples
- batch = eval_dataset[-num_leftover_samples:]
- batch = {k: np.array(v) for k, v in batch.items()}
+ eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
+ for batch in tqdm(
+ eval_loader, total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2
+ ):
_ = batch.pop("example_id")
_ = batch.pop("offset_mapping")
-
- predictions = eval_step(unreplicate(state), batch)
- start_logits = np.array([pred for pred in predictions[0]])
- end_logits = np.array([pred for pred in predictions[1]])
+ predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size)
+ start_logits = np.array(predictions[0])
+ end_logits = np.array(predictions[1])
all_start_logits.append(start_logits)
all_end_logits.append(end_logits)
diff --git a/examples/flax/summarization/requirements.txt b/examples/flax/summarization/requirements.txt
index 7507ae1b69c9..58c7c26af78a 100644
--- a/examples/flax/summarization/requirements.txt
+++ b/examples/flax/summarization/requirements.txt
@@ -3,3 +3,4 @@ jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.5
optax>=0.0.8
+evaluate>=0.2.0
diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py
index 3ebff73b98ff..c193fe0bc374 100644
--- a/examples/flax/summarization/run_summarization_flax.py
+++ b/examples/flax/summarization/run_summarization_flax.py
@@ -20,6 +20,7 @@
import json
import logging
+import math
import os
import sys
import time
@@ -32,16 +33,17 @@
import datasets
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
-from datasets import Dataset, load_dataset, load_metric
+from datasets import Dataset, load_dataset
from tqdm import tqdm
+import evaluate
import jax
import jax.numpy as jnp
import optax
import transformers
from filelock import FileLock
from flax import jax_utils, traverse_util
-from flax.jax_utils import unreplicate
+from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
@@ -54,7 +56,7 @@
HfArgumentParser,
is_tensorboard_available,
)
-from transformers.utils import get_full_repo_name, is_offline_mode
+from transformers.utils import get_full_repo_name, is_offline_mode, send_example_telemetry
logger = logging.getLogger(__name__)
@@ -149,8 +151,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -173,14 +176,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -217,45 +225,57 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the `max_length` param of `model.generate`, which is used "
- "during evaluation."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the `max_length` param of `model.generate`, which is used "
+ "during evaluation."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -271,8 +291,10 @@ class DataTrainingArguments:
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
- "which is used during evaluation."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
+ "which is used during evaluation."
+ )
},
)
overwrite_cache: bool = field(
@@ -315,26 +337,28 @@ def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
-def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
+def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
"""
- Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
- Shuffle batches if `shuffle` is `True`.
+ Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
+ and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
"""
- steps_per_epoch = len(dataset) // batch_size
-
if shuffle:
batch_idx = jax.random.permutation(rng, len(dataset))
+ batch_idx = np.asarray(batch_idx)
else:
- batch_idx = jnp.arange(len(dataset))
+ batch_idx = np.arange(len(dataset))
- batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
- batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+ if drop_last:
+ steps_per_epoch = len(dataset) // batch_size
+ batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+ else:
+ steps_per_epoch = math.ceil(len(dataset) / batch_size)
+ batch_idx = np.array_split(batch_idx, steps_per_epoch)
for idx in batch_idx:
batch = dataset[idx]
- batch = {k: jnp.array(v) for k, v in batch.items()}
-
- batch = shard(batch)
+ batch = {k: np.array(v) for k, v in batch.items()}
yield batch
@@ -379,6 +403,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_summarization", model_args, data_args, framework="flax")
+
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
@@ -505,7 +533,6 @@ def main():
config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
- use_auth_token=True if model_args.use_auth_token else None,
)
if model.config.decoder_start_token_id is None:
@@ -563,10 +590,13 @@ def preprocess_function(examples):
)
# Setup the tokenizer for targets
- with tokenizer.as_target_tokenizer():
- labels = tokenizer(
- targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
- )
+ labels = tokenizer(
+ text_target=targets,
+ max_length=max_target_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="np",
+ )
model_inputs["labels"] = labels["input_ids"]
decoder_input_ids = shift_tokens_right_fn(
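
The target tokenization in the hunk above moves from the `as_target_tokenizer()` context manager to the `text_target=` argument available in recent tokenizer versions. A hedged sketch of the new call ("t5-small" is only an example checkpoint, not one used by this script):

from transformers import AutoTokenizer

# "t5-small" is an arbitrary example checkpoint for illustration.
tokenizer = AutoTokenizer.from_pretrained("t5-small")

labels = tokenizer(
    text_target=["a short reference summary"],
    max_length=16,
    padding="max_length",
    truncation=True,
    return_tensors="np",
)
print(labels["input_ids"].shape)  # (1, 16)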
@@ -630,7 +660,7 @@ def preprocess_function(examples):
)
# Metric
- metric = load_metric("rouge")
+ metric = evaluate.load("rouge")
def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
@@ -650,12 +680,9 @@ def compute_metrics(preds, labels):
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
- # Extract a few results from ROUGE
- result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
-
+ result = {k: round(v * 100, 4) for k, v in result.items()}
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
result["gen_len"] = np.mean(prediction_lens)
- result = {k: round(v, 4) for k, v in result.items()}
return result
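
The metric migration (`datasets.load_metric("rouge")` to `evaluate.load("rouge")`) is why the aggregation above got simpler: the `evaluate` implementation returns aggregated floats directly rather than objects with `.mid.fmeasure`. A small sketch, assuming the metric script can be fetched on first use and with made-up inputs:

import evaluate

rouge = evaluate.load("rouge")
result = rouge.compute(
    predictions=["the cat sat on the mat"],
    references=["the cat sat on the mat"],
    use_stemmer=True,
)
print({k: round(v * 100, 4) for k, v in result.items()})
# e.g. {'rouge1': 100.0, 'rouge2': 100.0, 'rougeL': 100.0, 'rougeLsum': 100.0}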
# Enable tensorboard only on the master node
@@ -683,7 +710,8 @@ def compute_metrics(preds, labels):
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+ eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
@@ -700,15 +728,19 @@ def compute_metrics(preds, labels):
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
- # Note that this mask is specifically adapted for FlaxBart.
- # For FlaxT5, one should correct the layer norm parameter naming
- # accordingly - see `run_t5_mlm_flax.py` e.g.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
- layer_norm_params = [
- (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
- ]
- flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
+ # find out all LayerNorm parameters
+ layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+ layer_norm_named_params = set(
+ [
+ layer[-2:]
+ for layer_norm_name in layer_norm_candidates
+ for layer in flat_params.keys()
+ if layer_norm_name in "".join(layer).lower()
+ ]
+ )
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
@@ -823,6 +855,7 @@ def generate_step(params, batch):
# train
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
+ batch = shard(batch)
state, train_metric = p_train_step(state, batch)
train_metrics.append(train_metric)
@@ -831,7 +864,8 @@ def generate_step(params, batch):
train_metric = unreplicate(train_metric)
epochs.write(
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
# ======================== Evaluating ==============================
@@ -839,21 +873,23 @@ def generate_step(params, batch):
eval_preds = []
eval_labels = []
- eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
- eval_steps = len(eval_dataset) // eval_batch_size
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
+ eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
# Model forward
batch = next(eval_loader)
labels = batch["labels"]
- metrics = p_eval_step(state.params, batch)
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+ state.params, batch, min_device_batch=per_device_eval_batch_size
+ )
eval_metrics.append(metrics)
# generation
if data_args.predict_with_generate:
- generated_ids = p_generate_step(state.params, batch)
+ generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
- eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+ eval_labels.extend(labels)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
@@ -892,21 +928,23 @@ def generate_step(params, batch):
pred_generations = []
pred_labels = []
- pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
- pred_steps = len(predict_dataset) // eval_batch_size
+ pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size, drop_last=False)
+ pred_steps = math.ceil(len(predict_dataset) / eval_batch_size)
for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
# Model forward
batch = next(pred_loader)
labels = batch["labels"]
- metrics = p_eval_step(state.params, batch)
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+ state.params, batch, min_device_batch=per_device_eval_batch_size
+ )
pred_metrics.append(metrics)
# generation
if data_args.predict_with_generate:
- generated_ids = p_generate_step(state.params, batch)
+ generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
- pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+ pred_labels.extend(labels)
# normalize prediction metrics
pred_metrics = get_metrics(pred_metrics)
diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py
index 23144069d7dd..e0dfab2f52e9 100755
--- a/examples/flax/text-classification/run_flax_glue.py
+++ b/examples/flax/text-classification/run_flax_glue.py
@@ -16,26 +16,27 @@
""" Finetuning a š¤ Flax Transformers model for sequence classification on GLUE."""
import json
import logging
+import math
import os
import random
import sys
import time
from dataclasses import dataclass, field
-from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
import datasets
import numpy as np
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
from tqdm import tqdm
+import evaluate
import jax
import jax.numpy as jnp
import optax
import transformers
from flax import struct, traverse_util
-from flax.jax_utils import replicate, unreplicate
+from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
@@ -48,12 +49,12 @@
TrainingArguments,
is_tensorboard_available,
)
-from transformers.utils import check_min_version, get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset
@@ -103,8 +104,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -148,29 +151,37 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. If set, sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. If set, sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
@@ -216,7 +227,17 @@ class TrainState(train_state.TrainState):
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
- flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+ # find out all LayerNorm parameters
+ layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+ layer_norm_named_params = set(
+ [
+ layer[-2:]
+ for layer_norm_name in layer_norm_candidates
+ for layer in flat_params.keys()
+ if layer_norm_name in "".join(layer).lower()
+ ]
+ )
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
tx = optax.adamw(
@@ -280,11 +301,15 @@ def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
def glue_eval_data_collator(dataset: Dataset, batch_size: int):
- """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
- for i in range(len(dataset) // batch_size):
- batch = dataset[i * batch_size : (i + 1) * batch_size]
+ """Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
+ batch_idx = np.arange(len(dataset))
+
+ steps_per_epoch = math.ceil(len(dataset) / batch_size)
+ batch_idx = np.array_split(batch_idx, steps_per_epoch)
+
+ for idx in batch_idx:
+ batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()}
- batch = shard(batch)
yield batch
@@ -298,6 +323,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_glue", model_args, data_args, framework="flax")
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -497,8 +526,9 @@ def write_eval_metric(summary_writer, eval_metrics, step):
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
- train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
- eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.local_device_count()
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+ eval_batch_size = per_device_eval_batch_size * jax.device_count()
learning_rate_fn = create_learning_rate_fn(
len(train_dataset),
@@ -541,9 +571,9 @@ def eval_step(state, batch):
p_eval_step = jax.pmap(eval_step, axis_name="batch")
if data_args.task_name is not None:
- metric = load_metric("glue", data_args.task_name)
+ metric = evaluate.load("glue", data_args.task_name)
else:
- metric = load_metric("accuracy")
+ metric = evaluate.load("accuracy")
logger.info(f"===== Starting training ({num_epochs} epochs) =====")
train_time = 0
@@ -585,7 +615,8 @@ def eval_step(state, batch):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
train_metrics = []
@@ -596,26 +627,15 @@ def eval_step(state, batch):
eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
for batch in tqdm(
eval_loader,
- total=len(eval_dataset) // eval_batch_size,
+ total=math.ceil(len(eval_dataset) / eval_batch_size),
desc="Evaluating ...",
position=2,
):
labels = batch.pop("labels")
- predictions = p_eval_step(state, batch)
- metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
-
- # evaluate also on leftover examples (not divisible by batch_size)
- num_leftover_samples = len(eval_dataset) % eval_batch_size
-
- # make sure leftover batch is evaluated on one device
- if num_leftover_samples > 0 and jax.process_index() == 0:
- # take leftover samples
- batch = eval_dataset[-num_leftover_samples:]
- batch = {k: np.array(v) for k, v in batch.items()}
-
- labels = batch.pop("labels")
- predictions = eval_step(unreplicate(state), batch)
- metric.add_batch(predictions=predictions, references=labels)
+ predictions = pad_shard_unpad(p_eval_step)(
+ state, batch, min_device_batch=per_device_eval_batch_size
+ )
+ metric.add_batch(predictions=np.array(predictions), references=labels)
eval_metric = metric.compute()
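
The same `evaluate` migration applies to the GLUE metrics used above. A self-contained sketch of the `add_batch`/`compute` flow with hypothetical predictions ("mrpc" is just one example GLUE config):

import evaluate

metric = evaluate.load("glue", "mrpc")
metric.add_batch(predictions=[0, 1, 1, 0], references=[0, 1, 0, 0])
print(metric.compute())  # e.g. {'accuracy': 0.75, 'f1': 0.666...}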
diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py
index a0e01b080275..ad68c0997fed 100644
--- a/examples/flax/token-classification/run_flax_ner.py
+++ b/examples/flax/token-classification/run_flax_ner.py
@@ -16,6 +16,7 @@
""" Fine-tuning a š¤ Flax Transformers model on token classification tasks (NER, POS, CHUNKS)"""
import json
import logging
+import math
import os
import random
import sys
@@ -28,15 +29,16 @@
import datasets
import numpy as np
-from datasets import ClassLabel, load_dataset, load_metric
+from datasets import ClassLabel, load_dataset
from tqdm import tqdm
+import evaluate
import jax
import jax.numpy as jnp
import optax
import transformers
from flax import struct, traverse_util
-from flax.jax_utils import replicate, unreplicate
+from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
@@ -47,13 +49,13 @@
HfArgumentParser,
is_tensorboard_available,
)
-from transformers.utils import check_min_version, get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
@@ -150,8 +152,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -196,36 +200,46 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. If set, sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. If set, sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
label_all_tokens: bool = field(
default=False,
metadata={
- "help": "Whether to put the label for one word on all tokens of generated by that word or just on the "
- "one (in which case the other tokens will have a padding index)."
+ "help": (
+ "Whether to put the label for one word on all tokens of generated by that word or just on the "
+ "one (in which case the other tokens will have a padding index)."
+ )
},
)
return_entity_level_metrics: bool = field(
@@ -272,12 +286,19 @@ class TrainState(train_state.TrainState):
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
- # Note that this mask is specifically adapted for FlaxBERT-like models.
- # For other models, one should correct the layer norm parameter naming
- # accordingly.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
- flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+ # find out all LayerNorm parameters
+ layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+ layer_norm_named_params = set(
+ [
+ layer[-2:]
+ for layer_norm_name in layer_norm_candidates
+ for layer in flat_params.keys()
+ if layer_norm_name in "".join(layer).lower()
+ ]
+ )
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
tx = optax.adamw(
@@ -332,11 +353,15 @@ def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
def eval_data_collator(dataset: Dataset, batch_size: int):
- """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
- for i in range(len(dataset) // batch_size):
- batch = dataset[i * batch_size : (i + 1) * batch_size]
+ """Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
+ batch_idx = np.arange(len(dataset))
+
+ steps_per_epoch = math.ceil(len(dataset) / batch_size)
+ batch_idx = np.array_split(batch_idx, steps_per_epoch)
+
+ for idx in batch_idx:
+ batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()}
- batch = shard(batch)
yield batch
@@ -354,6 +379,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_ner", model_args, data_args, framework="flax")
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -577,6 +606,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
dropout_rngs = jax.random.split(rng, jax.local_device_count())
train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()
learning_rate_fn = create_learning_rate_fn(
@@ -617,7 +647,7 @@ def eval_step(state, batch):
p_eval_step = jax.pmap(eval_step, axis_name="batch")
- metric = load_metric("seqeval")
+ metric = evaluate.load("seqeval")
def get_labels(y_pred, y_true):
# Transform predictions and references tensos to numpy arrays
@@ -693,7 +723,8 @@ def compute_metrics():
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
train_metrics = []
@@ -704,34 +735,16 @@ def compute_metrics():
# evaluate
for batch in tqdm(
eval_data_collator(eval_dataset, eval_batch_size),
- total=len(eval_dataset) // eval_batch_size,
+ total=math.ceil(len(eval_dataset) / eval_batch_size),
desc="Evaluating ...",
position=2,
):
labels = batch.pop("labels")
- predictions = p_eval_step(state, batch)
- predictions = np.array([pred for pred in chain(*predictions)])
- labels = np.array([label for label in chain(*labels)])
- labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
- preds, refs = get_labels(predictions, labels)
- metric.add_batch(
- predictions=preds,
- references=refs,
+ predictions = pad_shard_unpad(p_eval_step)(
+ state, batch, min_device_batch=per_device_eval_batch_size
)
-
- # evaluate also on leftover examples (not divisible by batch_size)
- num_leftover_samples = len(eval_dataset) % eval_batch_size
-
- # make sure leftover batch is evaluated on one device
- if num_leftover_samples > 0 and jax.process_index() == 0:
- # take leftover samples
- batch = eval_dataset[-num_leftover_samples:]
- batch = {k: np.array(v) for k, v in batch.items()}
-
- labels = batch.pop("labels")
- predictions = eval_step(unreplicate(state), batch)
- labels = np.array(labels)
- labels[np.array(batch["attention_mask"]) == 0] = -100
+ predictions = np.array(predictions)
+ labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
preds, refs = get_labels(predictions, labels)
metric.add_batch(
predictions=preds,
@@ -744,7 +757,8 @@ def compute_metrics():
logger.info(f"Step... ({cur_step}/{total_steps} | Validation metrics: {eval_metrics}")
else:
logger.info(
- f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc: {eval_metrics['accuracy']})"
+ f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc:"
+ f" {eval_metrics['accuracy']})"
)
if has_tensorboard and jax.process_index() == 0:
@@ -766,28 +780,12 @@ def compute_metrics():
eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
labels = batch.pop("labels")
- predictions = p_eval_step(state, batch)
- predictions = np.array([pred for pred in chain(*predictions)])
- labels = np.array([label for label in chain(*labels)])
+ predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size)
+ predictions = np.array(predictions)
labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
preds, refs = get_labels(predictions, labels)
metric.add_batch(predictions=preds, references=refs)
- # evaluate also on leftover examples (not divisible by batch_size)
- num_leftover_samples = len(eval_dataset) % eval_batch_size
-
- # make sure leftover batch is evaluated on one device
- if num_leftover_samples > 0 and jax.process_index() == 0:
- # take leftover samples
- batch = eval_dataset[-num_leftover_samples:]
- batch = {k: np.array(v) for k, v in batch.items()}
-
- labels = np.array(batch.pop("labels"))
- predictions = eval_step(unreplicate(state), batch)
- labels[np.array(batch["attention_mask"]) == 0] = -100
- preds, refs = get_labels(predictions, labels)
- metric.add_batch(predictions=preds, references=refs)
-
eval_metrics = compute_metrics()
if jax.process_index() == 0:
diff --git a/examples/flax/vision/run_image_classification.py b/examples/flax/vision/run_image_classification.py
index 0dc7b2f95742..3de3c977ab1d 100644
--- a/examples/flax/vision/run_image_classification.py
+++ b/examples/flax/vision/run_image_classification.py
@@ -40,7 +40,7 @@
import optax
import transformers
from flax import jax_utils
-from flax.jax_utils import unreplicate
+from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
@@ -53,7 +53,7 @@
is_tensorboard_available,
set_seed,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
logger = logging.getLogger(__name__)
@@ -134,8 +134,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -151,14 +152,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -179,15 +185,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
overwrite_cache: bool = field(
@@ -246,6 +256,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_image_classification", model_args, data_args, framework="flax")
+
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
@@ -349,13 +363,13 @@ def main():
config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
- use_auth_token=True if model_args.use_auth_token else None,
)
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+ eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
@@ -385,7 +399,7 @@ def collate_fn(examples):
shuffle=False,
num_workers=data_args.preprocessing_num_workers,
persistent_workers=True,
- drop_last=True,
+ drop_last=False,
collate_fn=collate_fn,
)
@@ -509,7 +523,8 @@ def eval_step(params, batch):
train_step_progress_bar.close()
epochs.write(
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
# ======================== Evaluating ==============================
@@ -518,8 +533,9 @@ def eval_step(params, batch):
eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
for batch in eval_loader:
# Model forward
- batch = shard(batch)
- metrics = p_eval_step(state.params, batch)
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+ state.params, batch, min_device_batch=per_device_eval_batch_size
+ )
eval_metrics.append(metrics)
eval_step_progress_bar.update(1)
diff --git a/examples/legacy/multiple_choice/run_multiple_choice.py b/examples/legacy/multiple_choice/run_multiple_choice.py
index aeb9b9dc434a..d8007da6cb67 100644
--- a/examples/legacy/multiple_choice/run_multiple_choice.py
+++ b/examples/legacy/multiple_choice/run_multiple_choice.py
@@ -78,8 +78,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -102,7 +104,8 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
+ " --overwrite_output_dir to overcome."
)
# Setup logging
diff --git a/examples/legacy/multiple_choice/utils_multiple_choice.py b/examples/legacy/multiple_choice/utils_multiple_choice.py
index 2b6b5cc18322..3dbc3689cc48 100644
--- a/examples/legacy/multiple_choice/utils_multiple_choice.py
+++ b/examples/legacy/multiple_choice/utils_multiple_choice.py
@@ -182,7 +182,7 @@ def __init__(
)
def gen():
- for (ex_index, ex) in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
+ for ex_index, ex in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
if ex_index % 10000 == 0:
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
@@ -297,7 +297,7 @@ def _read_txt(self, input_dir):
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
- for (_, data_raw) in enumerate(lines):
+ for _, data_raw in enumerate(lines):
race_id = "%s-%s" % (set_type, data_raw["race_id"])
article = data_raw["article"]
for i in range(len(data_raw["answers"])):
@@ -518,7 +518,7 @@ def convert_examples_to_features(
label_map = {label: i for i, label in enumerate(label_list)}
features = []
- for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
+ for ex_index, example in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
if ex_index % 10000 == 0:
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
choices_inputs = []
diff --git a/examples/legacy/pytorch-lightning/lightning_base.py b/examples/legacy/pytorch-lightning/lightning_base.py
index b7f53076e3bc..b3104a25a8b1 100644
--- a/examples/legacy/pytorch-lightning/lightning_base.py
+++ b/examples/legacy/pytorch-lightning/lightning_base.py
@@ -312,8 +312,10 @@ def add_generic_args(parser, root_dir) -> None:
"--fp16_opt_level",
type=str,
default="O2",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
diff --git a/examples/legacy/pytorch-lightning/run_glue.py b/examples/legacy/pytorch-lightning/run_glue.py
index abb06bf526bb..63b58bcf413c 100644
--- a/examples/legacy/pytorch-lightning/run_glue.py
+++ b/examples/legacy/pytorch-lightning/run_glue.py
@@ -148,8 +148,10 @@ def add_model_specific_args(parser, root_dir):
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
diff --git a/examples/legacy/pytorch-lightning/run_ner.py b/examples/legacy/pytorch-lightning/run_ner.py
index 1066c6fed48c..b1bdd125c22e 100644
--- a/examples/legacy/pytorch-lightning/run_ner.py
+++ b/examples/legacy/pytorch-lightning/run_ner.py
@@ -173,8 +173,10 @@ def add_model_specific_args(parser, root_dir):
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
diff --git a/examples/legacy/pytorch-lightning/run_ner.sh b/examples/legacy/pytorch-lightning/run_ner.sh
index 2913473eb8cd..a5b185aa960d 100755
--- a/examples/legacy/pytorch-lightning/run_ner.sh
+++ b/examples/legacy/pytorch-lightning/run_ner.sh
@@ -5,7 +5,7 @@ pip install -r ../requirements.txt
## The relevant files are currently on a shared Google
## drive at https://drive.google.com/drive/folders/1kC0I2UGl2ltrluI9NqDjaQJGw5iliw_J
-## Monitor for changes and eventually migrate to nlp dataset
+## Monitor for changes and eventually migrate to use the `datasets` library
curl -L 'https://drive.google.com/uc?export=download&id=1Jjhbal535VVz2ap4v4r_rN1UEHTdLK5P' \
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp
curl -L 'https://drive.google.com/uc?export=download&id=1ZfRcQThdtAR5PPRjIDtrVP7BtXSCUBbm' \
diff --git a/examples/legacy/question-answering/run_squad.py b/examples/legacy/question-answering/run_squad.py
index fbf2ebd6351a..674e7a9accbf 100644
--- a/examples/legacy/question-answering/run_squad.py
+++ b/examples/legacy/question-answering/run_squad.py
@@ -551,8 +551,10 @@ def main():
"--max_seq_length",
default=384,
type=int,
- help="The maximum total input sequence length after WordPiece tokenization. Sequences "
- "longer than this will be truncated, and sequences shorter than this will be padded.",
+ help=(
+ "The maximum total input sequence length after WordPiece tokenization. Sequences "
+ "longer than this will be truncated, and sequences shorter than this will be padded."
+ ),
)
parser.add_argument(
"--doc_stride",
@@ -564,8 +566,10 @@ def main():
"--max_query_length",
default=64,
type=int,
- help="The maximum number of tokens for the question. Questions longer than this will "
- "be truncated to this length.",
+ help=(
+ "The maximum number of tokens for the question. Questions longer than this will "
+ "be truncated to this length."
+ ),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -610,20 +614,27 @@ def main():
"--max_answer_length",
default=30,
type=int,
- help="The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another.",
+ help=(
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ ),
)
parser.add_argument(
"--verbose_logging",
action="store_true",
- help="If true, all of the warnings related to data processing will be printed. "
- "A number of warnings are expected for a normal SQuAD evaluation.",
+ help=(
+ "If true, all of the warnings related to data processing will be printed. "
+ "A number of warnings are expected for a normal SQuAD evaluation."
+ ),
)
parser.add_argument(
"--lang_id",
default=0,
type=int,
- help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)",
+ help=(
+ "language id of input for language-specific xlm models (see"
+ " tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
+ ),
)
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
@@ -652,8 +663,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
diff --git a/examples/legacy/question-answering/run_squad_trainer.py b/examples/legacy/question-answering/run_squad_trainer.py
index 7089326372ea..314b140e828c 100644
--- a/examples/legacy/question-answering/run_squad_trainer.py
+++ b/examples/legacy/question-answering/run_squad_trainer.py
@@ -84,7 +84,8 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
+ " --overwrite_output_dir to overcome."
)
# Setup logging
diff --git a/examples/legacy/run_language_modeling.py b/examples/legacy/run_language_modeling.py
index 12b62f5d816c..59490f710e13 100755
--- a/examples/legacy/run_language_modeling.py
+++ b/examples/legacy/run_language_modeling.py
@@ -68,7 +68,10 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization. Leave None if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization. Leave None if you want to train a model from"
+ " scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -99,8 +102,10 @@ class DataTrainingArguments:
train_data_files: Optional[str] = field(
default=None,
metadata={
- "help": "The input training data files (multiple files in glob format). "
- "Very often splitting large files to smaller files can prevent tokenizer going out of memory"
+ "help": (
+ "The input training data files (multiple files in glob format). "
+ "Very often splitting large files to smaller files can prevent tokenizer going out of memory"
+ )
},
)
eval_data_file: Optional[str] = field(
@@ -130,7 +135,10 @@ class DataTrainingArguments:
plm_probability: float = field(
default=1 / 6,
metadata={
- "help": "Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling."
+ "help": (
+ "Ratio of length of a span of masked tokens to surrounding context length for permutation language"
+ " modeling."
+ )
},
)
max_span_length: int = field(
@@ -140,9 +148,11 @@ class DataTrainingArguments:
block_size: int = field(
default=-1,
metadata={
- "help": "Optional input sequence length after tokenization."
- "The training dataset will be truncated in block of this size for training."
- "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ "help": (
+ "Optional input sequence length after tokenization."
+ "The training dataset will be truncated in block of this size for training."
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ )
},
)
overwrite_cache: bool = field(
@@ -206,7 +216,8 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
+ " --overwrite_output_dir to overcome."
)
# Setup logging
@@ -253,8 +264,8 @@ def main():
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
else:
raise ValueError(
- "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
- "and load it from here, using --tokenizer_name"
+ "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another"
+ " script, save it,and load it from here, using --tokenizer_name"
)
if model_args.model_name_or_path:
diff --git a/examples/legacy/run_openai_gpt.py b/examples/legacy/run_openai_gpt.py
index 2af3e267d2e7..1f02570f8f51 100755
--- a/examples/legacy/run_openai_gpt.py
+++ b/examples/legacy/run_openai_gpt.py
@@ -126,15 +126,15 @@ def main():
"--max_steps",
default=-1,
type=int,
- help="If > 0: set total number of training \
- steps to perform. Override num_train_epochs.",
+ help=(
+ "If > 0: set total number of training steps to perform. Override num_train_epochs."
+ ),
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
- help="Number of updates steps to accumulate before\
- performing a backward/update pass.",
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--learning_rate", type=float, default=6.25e-5)
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
diff --git a/examples/legacy/run_swag.py b/examples/legacy/run_swag.py
index e7760410892f..5cac1567243c 100755
--- a/examples/legacy/run_swag.py
+++ b/examples/legacy/run_swag.py
@@ -516,8 +516,10 @@ def main():
"--max_seq_length",
default=384,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences "
- "longer than this will be truncated, and sequences shorter than this will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences "
+ "longer than this will be truncated, and sequences shorter than this will be padded."
+ ),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -576,8 +578,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
diff --git a/examples/legacy/seq2seq/finetune_trainer.py b/examples/legacy/seq2seq/finetune_trainer.py
index 3efc8f90f25b..f174f7fb5018 100755
--- a/examples/legacy/seq2seq/finetune_trainer.py
+++ b/examples/legacy/seq2seq/finetune_trainer.py
@@ -90,31 +90,39 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=142,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. "
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. "
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
test_max_target_length: Optional[int] = field(
default=142,
metadata={
- "help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for test target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
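
In the dataclass-based scripts, those concatenated help strings are surfaced through `HfArgumentParser`, which maps each `field(metadata={"help": ...})` entry to a command-line flag. A small illustrative sketch, assuming only `transformers` is installed; the `DemoDataArguments` class is invented for the example:

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DemoDataArguments:
    # Mirrors max_source_length above; the parenthesized pieces become one help string.
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )


parser = HfArgumentParser(DemoDataArguments)
(demo_args,) = parser.parse_args_into_dataclasses(["--max_source_length", "512"])
print(demo_args.max_source_length)  # 512
```
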
diff --git a/examples/legacy/seq2seq/old_test_calculate_rouge.py b/examples/legacy/seq2seq/old_test_calculate_rouge.py
index bd1dd57a2725..17b87cb481a6 100644
--- a/examples/legacy/seq2seq/old_test_calculate_rouge.py
+++ b/examples/legacy/seq2seq/old_test_calculate_rouge.py
@@ -22,15 +22,30 @@
PRED = [
- 'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe depression" German airline confirms it knew of Andreas Lubitz\'s depression years before he took control.',
- "The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the body.",
- "Amnesty International releases its annual report on the death penalty. The report catalogs the use of state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital punishment.",
+ 'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the'
+ ' final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe'
+ " depression\" German airline confirms it knew of Andreas Lubitz's depression years before he took control.",
+ "The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal"
+ " accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's"
+ " founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the"
+ " body.",
+ "Amnesty International releases its annual report on the death penalty. The report catalogs the use of"
+ " state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the"
+ " world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital"
+ " punishment.",
]
TGT = [
- 'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says .',
- "Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June . Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .",
- "Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to death . Organization claims that governments around the world are using the threat of terrorism to advance executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death sentences up by 28% .",
+ 'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .'
+ ' Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz'
+ " had informed his Lufthansa training school of an episode of severe depression, airline says .",
+ "Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June ."
+ " Israel and the United States opposed the move, which could open the door to war crimes investigations against"
+ " Israelis .",
+ "Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to"
+ " death . Organization claims that governments around the world are using the threat of terrorism to advance"
+ " executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death"
+ " sentences up by 28% .",
]
@@ -65,7 +80,8 @@ def test_single_sent_scores_dont_depend_on_newline_sep():
]
tgt = [
"Margot Frank, died in 1945, a month earlier than previously thought.",
- 'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the final seconds on board Flight 9525.',
+ 'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of'
+ " the final seconds on board Flight 9525.",
]
assert calculate_rouge(pred, tgt, newline_sep=True) == calculate_rouge(pred, tgt, newline_sep=False)
diff --git a/examples/legacy/seq2seq/run_eval.py b/examples/legacy/seq2seq/run_eval.py
index e21f57c1c609..a8aa8e7ef95d 100755
--- a/examples/legacy/seq2seq/run_eval.py
+++ b/examples/legacy/seq2seq/run_eval.py
@@ -121,7 +121,10 @@ def run_generate(verbose=True):
nargs="?",
type=str,
const=datetime_now(),
- help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.",
+ help=(
+ "use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g."
+ " lang=en-ru. If no value is passed, the current datetime string will be used."
+ ),
)
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
args, rest = parser.parse_known_args()
diff --git a/examples/legacy/seq2seq/run_eval_search.py b/examples/legacy/seq2seq/run_eval_search.py
index f7b3bda0f54f..e1a0c8660c9b 100755
--- a/examples/legacy/seq2seq/run_eval_search.py
+++ b/examples/legacy/seq2seq/run_eval_search.py
@@ -35,7 +35,7 @@ def parse_search_arg(search):
groups = search.split()
entries = {k: vs for k, vs in (g.split("=") for g in groups)}
entry_names = list(entries.keys())
- sets = [list((f"--{k} {v}") for v in vs.split(":")) for k, vs in entries.items()]
+ sets = [list(f"--{k} {v}" for v in vs.split(":")) for k, vs in entries.items()]
matrix = [list(x) for x in itertools.product(*sets)]
return matrix, entry_names
@@ -66,7 +66,10 @@ def run_search():
prog = sys.argv[0]
parser = argparse.ArgumentParser(
- usage="\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore refer to `run_eval.py -h` for the complete list."
+ usage=(
+ "\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore"
+ " refer to `run_eval.py -h` for the complete list."
+ )
)
parser.add_argument(
"--search",
@@ -83,7 +86,10 @@ def run_search():
nargs="?",
type=str,
const=datetime_now(),
- help="add custom notes to be printed before the results table. If no value is passed, the current datetime string will be used.",
+ help=(
+ "add custom notes to be printed before the results table. If no value is passed, the current datetime"
+ " string will be used."
+ ),
)
args, args_main = parser.parse_known_args()
# we share some of the args
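
For context on what `parse_search_arg` produces: the `--search` string is split into `key=value1:value2` groups and expanded with `itertools.product` into one flag combination per `run_eval.py` run. A standalone sketch of that expansion, with made-up values:

```python
import itertools

# Example --search value: each key lists candidate values separated by ":".
search = "num_beams=5:10 length_penalty=0.8:1.0"

entries = {k: vs for k, vs in (g.split("=") for g in search.split())}
sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
matrix = [list(x) for x in itertools.product(*sets)]

for combo in matrix:
    print(" ".join(combo))
# --num_beams 5 --length_penalty 0.8
# --num_beams 5 --length_penalty 1.0
# --num_beams 10 --length_penalty 0.8
# --num_beams 10 --length_penalty 1.0
```
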
diff --git a/examples/legacy/seq2seq/seq2seq_trainer.py b/examples/legacy/seq2seq/seq2seq_trainer.py
index eeff082499c4..dbf12725f2db 100644
--- a/examples/legacy/seq2seq/seq2seq_trainer.py
+++ b/examples/legacy/seq2seq/seq2seq_trainer.py
@@ -57,9 +57,10 @@ def __init__(self, config=None, data_args=None, *args, **kwargs):
super().__init__(*args, **kwargs)
if config is None:
- assert isinstance(
- self.model, PreTrainedModel
- ), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}"
+ assert isinstance(self.model, PreTrainedModel), (
+ "If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is"
+ f" {self.model.__class__}"
+ )
self.config = self.model.config
else:
self.config = config
@@ -68,13 +69,15 @@ def __init__(self, config=None, data_args=None, *args, **kwargs):
self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size
if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
- assert (
- self.config.pad_token_id is not None
- ), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing."
+ assert self.config.pad_token_id is not None, (
+ "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss"
+ " calculation or doing label smoothing."
+ )
if self.config.pad_token_id is None and self.config.eos_token_id is not None:
logger.warning(
- f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
+ f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for"
+ " padding.."
)
if self.args.label_smoothing == 0:
@@ -248,7 +251,8 @@ def _pad_tensors_to_max_len(self, tensor, max_length):
if pad_token_id is None:
raise ValueError(
- f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}"
+ "Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be"
+ f" padded to `max_length`={max_length}"
)
padded_tensor = pad_token_id * torch.ones(
diff --git a/examples/legacy/seq2seq/xla_spawn.py b/examples/legacy/seq2seq/xla_spawn.py
index d84b41994564..5df6bfa2d5dc 100644
--- a/examples/legacy/seq2seq/xla_spawn.py
+++ b/examples/legacy/seq2seq/xla_spawn.py
@@ -39,9 +39,7 @@ def parse_args():
"""
parser = ArgumentParser(
description=(
- "PyTorch TPU distributed training launch "
- "helper utility that will spawn up "
- "multiple distributed processes"
+ "PyTorch TPU distributed training launch helper utility that will spawn up multiple distributed processes"
)
)
diff --git a/examples/legacy/text-classification/run_tf_text_classification.py b/examples/legacy/text-classification/run_tf_text_classification.py
index 3564775f30dd..1f845db04c04 100755
--- a/examples/legacy/text-classification/run_tf_text_classification.py
+++ b/examples/legacy/text-classification/run_tf_text_classification.py
@@ -168,8 +168,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -215,7 +217,8 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
+ " --overwrite_output_dir to overcome."
)
# Setup logging
diff --git a/examples/legacy/token-classification/README.md b/examples/legacy/token-classification/README.md
index cd9c1587032c..c2fa6eec7282 100644
--- a/examples/legacy/token-classification/README.md
+++ b/examples/legacy/token-classification/README.md
@@ -291,4 +291,4 @@ On the test dataset the following results could be achieved:
05/29/2020 23:34:02 - INFO - __main__ - eval_f1 = 0.47440836543753434
```
-WNUT’17 is a very difficult task. Current state-of-the-art results on this dataset can be found [here](http://nlpprogress.com/english/named_entity_recognition.html).
+WNUT’17 is a very difficult task. Current state-of-the-art results on this dataset can be found [here](https://nlpprogress.com/english/named_entity_recognition.html).
diff --git a/examples/legacy/token-classification/run.sh b/examples/legacy/token-classification/run.sh
index f5cbf0d50e02..b5f1e5f83bc7 100755
--- a/examples/legacy/token-classification/run.sh
+++ b/examples/legacy/token-classification/run.sh
@@ -1,6 +1,6 @@
## The relevant files are currently on a shared Google
## drive at https://drive.google.com/drive/folders/1kC0I2UGl2ltrluI9NqDjaQJGw5iliw_J
-## Monitor for changes and eventually migrate to nlp dataset
+## Monitor for changes and eventually migrate to use the `datasets` library
curl -L 'https://drive.google.com/uc?export=download&id=1Jjhbal535VVz2ap4v4r_rN1UEHTdLK5P' \
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp
curl -L 'https://drive.google.com/uc?export=download&id=1ZfRcQThdtAR5PPRjIDtrVP7BtXSCUBbm' \
diff --git a/examples/legacy/token-classification/run_ner.py b/examples/legacy/token-classification/run_ner.py
index a653ecb91c69..477ccb50fb25 100644
--- a/examples/legacy/token-classification/run_ner.py
+++ b/examples/legacy/token-classification/run_ner.py
@@ -87,8 +87,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -116,7 +118,8 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
+ " --overwrite_output_dir to overcome."
)
module = import_module("tasks")
diff --git a/examples/legacy/token-classification/run_tf_ner.py b/examples/legacy/token-classification/run_tf_ner.py
index 0169a10f24ac..857d777238f2 100755
--- a/examples/legacy/token-classification/run_tf_ner.py
+++ b/examples/legacy/token-classification/run_tf_ner.py
@@ -88,8 +88,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -111,7 +113,8 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
+ " --overwrite_output_dir to overcome."
)
module = import_module("tasks")
diff --git a/examples/legacy/token-classification/utils_ner.py b/examples/legacy/token-classification/utils_ner.py
index 2537aecfca6a..35fcb5ef5b7d 100644
--- a/examples/legacy/token-classification/utils_ner.py
+++ b/examples/legacy/token-classification/utils_ner.py
@@ -103,7 +103,7 @@ def convert_examples_to_features(
label_map = {label: i for i, label in enumerate(label_list)}
features = []
- for (ex_index, example) in enumerate(examples):
+ for ex_index, example in enumerate(examples):
if ex_index % 10_000 == 0:
logger.info("Writing example %d of %d", ex_index, len(examples))
@@ -140,7 +140,7 @@ def convert_examples_to_features(
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
- # used as as the "sentence vector". Note that this only makes sense because
+ # used as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens += [sep_token]
label_ids += [pad_token_label_id]
diff --git a/examples/pytorch/README.md b/examples/pytorch/README.md
index 95d42bfc8b38..442511ead93a 100644
--- a/examples/pytorch/README.md
+++ b/examples/pytorch/README.md
@@ -15,12 +15,12 @@ limitations under the License.
# Examples
-This folder contains actively maintained examples of use of 🤗 Transformers using the PyTorch backend, organized along NLP tasks.
+This folder contains actively maintained examples of use of 🤗 Transformers using the PyTorch backend, organized by ML task.
## The Big Table of Tasks
Here is the list of all our examples:
-- with information on whether they are **built on top of `Trainer``** (if not, they still work, they might
+- with information on whether they are **built on top of `Trainer`** (if not, they still work, they might
just lack some features),
- whether or not they have a version using the [🤗 Accelerate](https://github.com/huggingface/accelerate) library.
- whether or not they leverage the [🤗 Datasets](https://github.com/huggingface/datasets) library.
diff --git a/examples/pytorch/_tests_requirements.txt b/examples/pytorch/_tests_requirements.txt
index 8c13e79aa44b..979890f4b79c 100644
--- a/examples/pytorch/_tests_requirements.txt
+++ b/examples/pytorch/_tests_requirements.txt
@@ -22,3 +22,4 @@ protobuf
torchvision
jiwer
librosa
+evaluate >= 0.2.0
diff --git a/examples/pytorch/audio-classification/README.md b/examples/pytorch/audio-classification/README.md
index 12eb5e6ed399..21da5b9935ca 100644
--- a/examples/pytorch/audio-classification/README.md
+++ b/examples/pytorch/audio-classification/README.md
@@ -18,13 +18,13 @@ limitations under the License.
The following examples showcase how to fine-tune `Wav2Vec2` for audio classification using PyTorch.
-Speech recognition models that have been pretrained in unsupervised fashion on audio data alone,
-*e.g.* [Wav2Vec2](https://huggingface.co/transformers/main/model_doc/wav2vec2.html),
-[HuBERT](https://huggingface.co/transformers/main/model_doc/hubert.html),
-[XLSR-Wav2Vec2](https://huggingface.co/transformers/main/model_doc/xlsr_wav2vec2.html), have shown to require only
+Speech recognition models that have been pretrained in unsupervised fashion on audio data alone,
+*e.g.* [Wav2Vec2](https://huggingface.co/transformers/main/model_doc/wav2vec2.html),
+[HuBERT](https://huggingface.co/transformers/main/model_doc/hubert.html),
+[XLSR-Wav2Vec2](https://huggingface.co/transformers/main/model_doc/xlsr_wav2vec2.html), have shown to require only
very little annotated data to yield good performance on speech classification datasets.
-## Single-GPU
+## Single-GPU
The following command shows how to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) on the 🗣️ [Keyword Spotting subset](https://huggingface.co/datasets/superb#ks) of the SUPERB dataset.
@@ -63,7 +63,9 @@ On a single V100 GPU (16GB), this script should run in ~14 minutes and yield acc
👀 See the results here: [anton-l/wav2vec2-base-ft-keyword-spotting](https://huggingface.co/anton-l/wav2vec2-base-ft-keyword-spotting)
-## Multi-GPU
+> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
+
+## Multi-GPU
The following command shows how to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) for 🌎 **Language Identification** on the [CommonLanguage dataset](https://huggingface.co/datasets/anton-l/common_language).
@@ -139,7 +141,7 @@ It has been verified that the script works for the following datasets:
| Dataset | Pretrained Model | # transformer layers | Accuracy on eval | GPU setup | Training time | Fine-tuned Model & Logs |
|---------|------------------|----------------------|------------------|-----------|---------------|--------------------------|
-| Keyword Spotting | [ntu-spml/distilhubert](https://huggingface.co/ntu-spml/distilhubert) | 2 | 0.9706 | 1 V100 GPU | 11min | [here](https://huggingface.co/anton-l/distilhubert-ft-keyword-spotting) |
+| Keyword Spotting | [ntu-spml/distilhubert](https://huggingface.co/ntu-spml/distilhubert) | 2 | 0.9706 | 1 V100 GPU | 11min | [here](https://huggingface.co/anton-l/distilhubert-ft-keyword-spotting) |
| Keyword Spotting | [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) | 12 | 0.9826 | 1 V100 GPU | 14min | [here](https://huggingface.co/anton-l/wav2vec2-base-ft-keyword-spotting) |
| Keyword Spotting | [facebook/hubert-base-ls960](https://huggingface.co/facebook/hubert-base-ls960) | 12 | 0.9819 | 1 V100 GPU | 14min | [here](https://huggingface.co/anton-l/hubert-base-ft-keyword-spotting) |
| Keyword Spotting | [asapp/sew-mid-100k](https://huggingface.co/asapp/sew-mid-100k) | 24 | 0.9757 | 1 V100 GPU | 15min | [here](https://huggingface.co/anton-l/sew-mid-100k-ft-keyword-spotting) |
diff --git a/examples/pytorch/audio-classification/run_audio_classification.py b/examples/pytorch/audio-classification/run_audio_classification.py
index 5ad561ee2b85..9ebd4fb00759 100644
--- a/examples/pytorch/audio-classification/run_audio_classification.py
+++ b/examples/pytorch/audio-classification/run_audio_classification.py
@@ -26,6 +26,7 @@
import numpy as np
from datasets import DatasetDict, load_dataset
+import evaluate
import transformers
from transformers import (
AutoConfig,
@@ -37,14 +38,14 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
@@ -86,8 +87,9 @@ class DataTrainingArguments:
eval_split_name: str = field(
default="validation",
metadata={
- "help": "The name of the training data set split to use (via the datasets library). Defaults to "
- "'validation'"
+ "help": (
+ "The name of the training data set split to use (via the datasets library). Defaults to 'validation'"
+ )
},
)
audio_column_name: str = field(
@@ -100,15 +102,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_length_seconds: float = field(
@@ -149,13 +155,19 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
freeze_feature_extractor: Optional[bool] = field(
default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
+ ignore_mismatched_sizes: bool = field(
+ default=False,
+ metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+ )
def __post_init__(self):
if not self.freeze_feature_extractor and self.freeze_feature_encoder:
@@ -186,6 +198,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_audio_classification", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -300,7 +316,7 @@ def val_transforms(batch):
id2label[str(i)] = label
# Load the accuracy metric from the datasets package
- metric = datasets.load_metric("accuracy")
+ metric = evaluate.load("accuracy")
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with
# `predictions` and `label_ids` fields) and has to return a dictionary string to float.
@@ -326,6 +342,7 @@ def compute_metrics(eval_pred):
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
+ ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
# freeze the convolutional waveform encoder
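
The metric loading above now goes through the standalone `evaluate` library rather than `datasets.load_metric`. A minimal sketch of how such a metric feeds a `compute_metrics` callback (assumes `pip install evaluate`; the toy `EvalPred` tuple only stands in for the `EvalPrediction` object the `Trainer` passes):

```python
from collections import namedtuple

import numpy as np
import evaluate

accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    # Reduce logits to class ids before scoring, as the example scripts do.
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)


EvalPred = namedtuple("EvalPred", ["predictions", "label_ids"])
fake = EvalPred(predictions=np.array([[0.1, 0.9], [0.8, 0.2]]), label_ids=np.array([1, 0]))
print(compute_metrics(fake))  # {'accuracy': 1.0}
```
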
diff --git a/examples/pytorch/contrastive-image-text/README.md b/examples/pytorch/contrastive-image-text/README.md
index 714fe36761c5..cfc6a627809f 100644
--- a/examples/pytorch/contrastive-image-text/README.md
+++ b/examples/pytorch/contrastive-image-text/README.md
@@ -43,7 +43,10 @@ cd ..
Having downloaded COCO dataset manually you should be able to load with the `ydshieh/coc_dataset_script` dataset loading script:
```py
-COCO_DIR = "data"
+import os
+import datasets
+
+COCO_DIR = os.path.join(os.getcwd(), "data")
ds = datasets.load_dataset("ydshieh/coco_dataset_script", "2017", data_dir=COCO_DIR)
```
@@ -84,7 +87,7 @@ Finally, we can run the example script to train the model:
python examples/pytorch/contrastive-image-text/run_clip.py \
--output_dir ./clip-roberta-finetuned \
--model_name_or_path ./clip-roberta \
- --data_dir ./data \
+ --data_dir $PWD/data \
--dataset_name ydshieh/coco_dataset_script \
--dataset_config_name=2017 \
--image_column image_path \
diff --git a/examples/pytorch/contrastive-image-text/run_clip.py b/examples/pytorch/contrastive-image-text/run_clip.py
index fc036f2a20fa..d3c5355f9d07 100644
--- a/examples/pytorch/contrastive-image-text/run_clip.py
+++ b/examples/pytorch/contrastive-image-text/run_clip.py
@@ -47,14 +47,14 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
@@ -89,8 +89,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
freeze_vision_model: bool = field(
@@ -132,22 +134,28 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
overwrite_cache: bool = field(
@@ -225,6 +233,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_clip", model_args, data_args)
+
# 2. Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/image-classification/README.md b/examples/pytorch/image-classification/README.md
index 2070c854c769..904981451c6f 100644
--- a/examples/pytorch/image-classification/README.md
+++ b/examples/pytorch/image-classification/README.md
@@ -62,9 +62,11 @@ python run_image_classification.py \
Note that you can replace the model and dataset by simply setting the `model_name_or_path` and `dataset_name` arguments respectively, with any model or dataset from the [hub](https://huggingface.co/). For an overview of all possible arguments, we refer to the [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) of the `TrainingArguments`, which can be passed as flags.
+> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
+
### Using your own data
-To use your own dataset, there are 2 ways:
+To use your own dataset, there are 2 ways:
- you can either provide your own folders as `--train_dir` and/or `--validation_dir` arguments
- you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
@@ -177,7 +179,7 @@ the means of the [🤗 `Accelerate`](https://github.com/huggingface/accelerate)
after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
You can then use your usual launchers to run in it in a distributed environment, but the easiest way is to run
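
The `--ignore_mismatched_sizes` flag mentioned in this README forwards to the `ignore_mismatched_sizes` argument of `from_pretrained`, which re-initializes any weight whose shape differs from the checkpoint (typically the classification head) instead of raising an error. A hedged sketch; the checkpoint and label count are arbitrary choices for illustration:

```python
from transformers import AutoModelForImageClassification

# The pretrained head has 1000 classes, so with a 10-class setup its weights are
# freshly initialized rather than triggering a size-mismatch error.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=10,
    ignore_mismatched_sizes=True,
)
print(model.classifier)  # a newly initialized Linear layer with 10 outputs
```
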
diff --git a/examples/pytorch/image-classification/requirements.txt b/examples/pytorch/image-classification/requirements.txt
index a789fee85eef..aadc0e9088f8 100644
--- a/examples/pytorch/image-classification/requirements.txt
+++ b/examples/pytorch/image-classification/requirements.txt
@@ -1,3 +1,3 @@
torch>=1.5.0
torchvision>=0.6.0
-datasets>=1.8.0
\ No newline at end of file
+datasets>=1.17.0
diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py
index ba85814bd784..28000015ab17 100644
--- a/examples/pytorch/image-classification/run_image_classification.py
+++ b/examples/pytorch/image-classification/run_image_classification.py
@@ -19,7 +19,6 @@
from dataclasses import dataclass, field
from typing import Optional
-import datasets
import numpy as np
import torch
from datasets import load_dataset
@@ -34,6 +33,7 @@
ToTensor,
)
+import evaluate
import transformers
from transformers import (
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
@@ -43,9 +43,10 @@
HfArgumentParser,
Trainer,
TrainingArguments,
+ set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
@@ -54,7 +55,7 @@
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
@@ -93,15 +94,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -140,10 +145,16 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
+ ignore_mismatched_sizes: bool = field(
+ default=False,
+ metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+ )
def collate_fn(examples):
@@ -165,6 +176,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_image_classification", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -200,6 +215,9 @@ def main():
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
+ # Set seed before initializing model.
+ set_seed(training_args.seed)
+
# Initialize our dataset and prepare it for the 'image-classification' task.
if data_args.dataset_name is not None:
dataset = load_dataset(
@@ -238,7 +256,7 @@ def main():
id2label[str(i)] = label
# Load the accuracy metric from the datasets package
- metric = datasets.load_metric("accuracy")
+ metric = evaluate.load("accuracy")
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
@@ -263,6 +281,7 @@ def compute_metrics(p):
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
+ ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name or model_args.model_name_or_path,
diff --git a/examples/pytorch/image-classification/run_image_classification_no_trainer.py b/examples/pytorch/image-classification/run_image_classification_no_trainer.py
index daf67015bfd2..1bd190d1303e 100644
--- a/examples/pytorch/image-classification/run_image_classification_no_trainer.py
+++ b/examples/pytorch/image-classification/run_image_classification_no_trainer.py
@@ -22,7 +22,7 @@
import datasets
import torch
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision.transforms import (
CenterCrop,
@@ -35,6 +35,7 @@
)
from tqdm.auto import tqdm
+import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -47,10 +48,13 @@
SchedulerType,
get_scheduler,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.22.0.dev0")
+
logger = get_logger(__name__)
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
@@ -62,7 +66,10 @@ def parse_args():
"--dataset_name",
type=str,
default="cifar10",
- help="The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private, dataset).",
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset)."
+ ),
)
parser.add_argument("--train_dir", type=str, default=None, help="A folder containing the training data.")
parser.add_argument("--validation_dir", type=str, default=None, help="A folder containing the validation data.")
@@ -70,15 +77,19 @@ def parse_args():
"--max_train_samples",
type=int,
default=None,
- help="For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set.",
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
)
parser.add_argument(
"--max_eval_samples",
type=int,
default=None,
- help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set.",
+ help=(
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ ),
)
parser.add_argument(
"--train_val_split",
@@ -156,7 +167,22 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
+ )
+ parser.add_argument(
+ "--ignore_mismatched_sizes",
+ action="store_true",
+ help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
)
args = parser.parse_args()
@@ -179,9 +205,21 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_image_classification_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator_log_kwargs = {}
+
+ if args.with_tracking:
+ accelerator_log_kwargs["log_with"] = args.report_to
+ accelerator_log_kwargs["logging_dir"] = args.output_dir
+
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
+
logger.info(accelerator.state)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
@@ -271,6 +309,7 @@ def main():
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
+ ignore_mismatched_sizes=args.ignore_mismatched_sizes,
)
# Preprocessing the datasets
@@ -341,17 +380,17 @@ def collate_fn(examples):
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- else:
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
- num_warmup_steps=args.num_warmup_steps,
- num_training_steps=args.max_train_steps,
+ num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
# Prepare everything with our `accelerator`.
@@ -361,7 +400,10 @@ def collate_fn(examples):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
@@ -371,15 +413,18 @@ def collate_fn(examples):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("image_classification_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("image_classification_no_trainer", experiment_config)
# Get the metric function
- metric = load_metric("accuracy")
+ metric = evaluate.load("accuracy")
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -427,17 +472,20 @@ def collate_fn(examples):
if resume_step is not None and step < resume_step:
completed_steps += 1
continue
- outputs = model(**batch)
- loss = outputs.loss
- # We keep track of the loss at each epoch
- if args.with_tracking:
- total_loss += loss.detach().float()
- loss = loss / args.gradient_accumulation_steps
- accelerator.backward(loss)
- if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+
+ with accelerator.accumulate(model):
+ outputs = model(**batch)
+ loss = outputs.loss
+ # We keep track of the loss at each epoch
+ if args.with_tracking:
+ total_loss += loss.detach().float()
+ accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1
@@ -468,19 +516,11 @@ def collate_fn(examples):
break
model.eval()
- samples_seen = 0
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
- predictions, references = accelerator.gather((predictions, batch["labels"]))
- # If we are in a multiprocess environment, the last batch has duplicates
- if accelerator.num_processes > 1:
- if step == len(eval_dataloader):
- predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
- references = references[: len(eval_dataloader.dataset) - samples_seen]
- else:
- samples_seen += references.shape[0]
+ predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
metric.add_batch(
predictions=predictions,
references=references,
@@ -493,10 +533,11 @@ def collate_fn(examples):
accelerator.log(
{
"accuracy": eval_metric,
- "train_loss": total_loss,
+ "train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
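
The reworked loop above leans on two Accelerate APIs: `accelerator.accumulate(model)`, which defers the optimizer step until enough micro-batches have contributed gradients, and `accelerator.gather_for_metrics(...)`, which drops the samples that distributed samplers duplicate in the last batch. A condensed toy sketch of that loop shape (random data, single process; assumes reasonably recent `accelerate` and `evaluate` releases):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import evaluate
from accelerate import Accelerator

# Toy data and model, just to exercise the two APIs end to end.
features = torch.randn(64, 8)
labels = torch.randint(0, 2, (64,))
train_dl = DataLoader(TensorDataset(features, labels), batch_size=8, shuffle=True)
eval_dl = DataLoader(TensorDataset(features, labels), batch_size=8)

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

accelerator = Accelerator(gradient_accumulation_steps=2)
model, optimizer, train_dl, eval_dl = accelerator.prepare(model, optimizer, train_dl, eval_dl)
metric = evaluate.load("accuracy")

model.train()
for inputs, targets in train_dl:
    # accumulate() only lets optimizer.step() take effect once enough
    # micro-batches have contributed gradients.
    with accelerator.accumulate(model):
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

model.eval()
for inputs, targets in eval_dl:
    with torch.no_grad():
        predictions = model(inputs).argmax(dim=-1)
    # gather_for_metrics trims the duplicated tail of the last batch in
    # multi-process runs, so each example is counted exactly once.
    predictions, references = accelerator.gather_for_metrics((predictions, targets))
    metric.add_batch(predictions=predictions, references=references)

print(metric.compute())
```
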
diff --git a/examples/pytorch/image-pretraining/run_mae.py b/examples/pytorch/image-pretraining/run_mae.py
index be65779fe3c8..3ac4106b11ac 100644
--- a/examples/pytorch/image-pretraining/run_mae.py
+++ b/examples/pytorch/image-pretraining/run_mae.py
@@ -34,7 +34,7 @@
ViTMAEForPreTraining,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
@@ -43,7 +43,7 @@
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
@@ -74,15 +74,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
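Most of the remaining hunks only rewrap these long `help` strings inside parentheses. The wrapping relies on Python's implicit concatenation of adjacent string literals, which joins them with no separator, so a missing trailing space is what produces text such as "initialization.Don't". A small illustration:

```python
# Adjacent string literals concatenate with no separator; the trailing space matters.
help_text = (
    "For debugging purposes or quicker training, truncate the number of training examples to this "
    "value if set."
)
assert help_text.endswith("examples to this value if set.")
```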
@@ -104,8 +108,9 @@ class ModelArguments:
model_name_or_path: str = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
config_name: Optional[str] = field(
@@ -114,8 +119,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
cache_dir: Optional[str] = field(
@@ -129,8 +136,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
mask_ratio: float = field(
@@ -166,6 +175,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_mae", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/image-pretraining/run_mim.py b/examples/pytorch/image-pretraining/run_mim.py
index ed39be7a1a15..7626e8be3632 100644
--- a/examples/pytorch/image-pretraining/run_mim.py
+++ b/examples/pytorch/image-pretraining/run_mim.py
@@ -37,7 +37,7 @@
TrainingArguments,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
@@ -48,7 +48,7 @@
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
@@ -87,15 +87,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -117,9 +121,11 @@ class ModelArguments:
model_name_or_path: str = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization. Can be a local path to a pytorch_model.bin or a "
- "checkpoint identifier on the hub. "
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization. Can be a local path to a pytorch_model.bin or a "
+ "checkpoint identifier on the hub. "
+ "Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -132,8 +138,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
cache_dir: Optional[str] = field(
@@ -148,20 +156,26 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
image_size: Optional[int] = field(
default=None,
metadata={
- "help": "The size (resolution) of each image. If not specified, will use `image_size` of the configuration."
+ "help": (
+ "The size (resolution) of each image. If not specified, will use `image_size` of the configuration."
+ )
},
)
patch_size: Optional[int] = field(
default=None,
metadata={
- "help": "The size (resolution) of each patch. If not specified, will use `patch_size` of the configuration."
+ "help": (
+ "The size (resolution) of each patch. If not specified, will use `patch_size` of the configuration."
+ )
},
)
encoder_stride: Optional[int] = field(
@@ -225,6 +239,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_mim", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index 04a6b4c26794..ca992c04562e 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -30,8 +30,9 @@
from typing import Optional
import datasets
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
+import evaluate
import transformers
from transformers import (
CONFIG_MAPPING,
@@ -48,12 +49,12 @@
)
from transformers.testing_utils import CaptureLogger
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -73,8 +74,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -84,8 +86,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
config_name: Optional[str] = field(
@@ -109,8 +113,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -141,24 +147,30 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
block_size: Optional[int] = field(
default=None,
metadata={
- "help": "Optional input sequence length after tokenization. "
- "The training dataset will be truncated in block of this size for training. "
- "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ "help": (
+ "Optional input sequence length after tokenization. "
+ "The training dataset will be truncated in block of this size for training. "
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ )
},
)
overwrite_cache: bool = field(
@@ -203,6 +215,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_clm", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -390,7 +406,8 @@ def tokenize_function(examples):
# clm input could be much much longer than block_size
if "Token indices sequence length is longer than the" in cl.out:
tok_logger.warning(
- "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
+ " before being passed to the model."
)
return output
@@ -476,7 +493,7 @@ def preprocess_logits_for_metrics(logits, labels):
logits = logits[0]
return logits.argmax(dim=-1)
- metric = load_metric("accuracy")
+ metric = evaluate.load("accuracy")
def compute_metrics(eval_preds):
preds, labels = eval_preds
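The `compute_metrics` stub shown as context feeds the new `evaluate` accuracy metric; in the script it also has to realign predictions and labels, since a causal LM's logit at position i predicts token i + 1. The exact body is not part of this hunk, so the sketch below is an assumption about the surrounding code:

```python
# Assumed shape of the surrounding code: `preds` are already argmax-ed token ids
# coming from `preprocess_logits_for_metrics` above.
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # Shift so that tokens < n are compared against token n, then flatten for the metric.
    labels = labels[:, 1:].reshape(-1)
    preds = preds[:, :-1].reshape(-1)
    return metric.compute(predictions=preds, references=labels)
```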
diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py
index e9ac967c5681..3fd67d5fbf66 100755
--- a/examples/pytorch/language-modeling/run_clm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py
@@ -45,7 +45,6 @@
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
- AdamW,
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
@@ -53,10 +52,13 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.22.0.dev0")
+
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -94,7 +96,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
+ required=False,
)
parser.add_argument(
"--config_name",
@@ -168,7 +170,11 @@ def parse_args():
"--block_size",
type=int,
default=None,
- help="Optional input sequence length after tokenization. The training dataset will be truncated in block of this size for training. Default to the model max input length for single sentence inputs (take into account special tokens).",
+ help=(
+ "Optional input sequence length after tokenization. The training dataset will be truncated in block of"
+ " this size for training. Default to the model max input length for single sentence inputs (take into"
+ " account special tokens)."
+ ),
)
parser.add_argument(
"--preprocessing_num_workers",
@@ -202,7 +208,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
)
args = parser.parse_args()
@@ -226,9 +242,21 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_clm_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator_log_kwargs = {}
+
+ if args.with_tracking:
+ accelerator_log_kwargs["log_with"] = args.report_to
+ accelerator_log_kwargs["logging_dir"] = args.output_dir
+
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -447,24 +475,24 @@ def group_texts(examples):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
model.tie_weights()
# Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- else:
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
- num_warmup_steps=args.num_warmup_steps,
- num_training_steps=args.max_train_steps,
+ num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
# Prepare everything with our `accelerator`.
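`transformers.AdamW` is deprecated in favour of `torch.optim.AdamW`, which accepts the same parameter-group layout. The groups themselves are only partially visible in this hunk, so the sketch below reconstructs the usual no-weight-decay split and should be read as an approximation of the surrounding code:

```python
# Approximate reconstruction of the optimizer setup around this hunk.
no_decay = ["bias", "layer_norm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": args.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,  # biases and LayerNorm weights are excluded from weight decay
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
```

The warmup and total step counts handed to `get_scheduler` are now multiplied by `gradient_accumulation_steps` because `lr_scheduler.step()` runs on every micro-batch inside the `accelerator.accumulate` block, so the scheduler's notion of a step is a batch rather than an optimizer update.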
@@ -474,7 +502,10 @@ def group_texts(examples):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
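Because `accelerator.prepare` can shard the dataloader, the step math is redone after it, but `max_train_steps` is only refreshed when it was derived from the epoch count rather than set explicitly; the epoch count is recomputed either way. In condensed form:

```python
# Condensed view of the post-`prepare` step accounting.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
    # max_train_steps was derived from num_train_epochs, so refresh it for the new dataloader length.
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Recompute the epoch count from the (possibly user-supplied) step budget.
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
```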
@@ -484,12 +515,15 @@ def group_texts(examples):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("clm_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("clm_no_trainer", experiment_config)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
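Tracker initialization now happens on the main process only, matching `accelerator.log`, which already logs only from that process; this avoids creating empty runs on the other workers. A sketch, with the nested `if` from the diff collapsed into one condition:

```python
# Sketch: initialize trackers only where logging actually happens.
if args.with_tracking and accelerator.is_main_process:
    experiment_config = vars(args)
    # TensorBoard cannot log Enum values, so store the raw string instead.
    experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
    accelerator.init_trackers("clm_no_trainer", experiment_config)
```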
@@ -538,17 +572,20 @@ def group_texts(examples):
if resume_step is not None and step < resume_step:
completed_steps += 1
continue
- outputs = model(**batch)
- loss = outputs.loss
- # We keep track of the loss at each epoch
- if args.with_tracking:
- total_loss += loss.detach().float()
- loss = loss / args.gradient_accumulation_steps
- accelerator.backward(loss)
- if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+
+ with accelerator.accumulate(model):
+ outputs = model(**batch)
+ loss = outputs.loss
+ # We keep track of the loss at each epoch
+ if args.with_tracking:
+ total_loss += loss.detach().float()
+ accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1
@@ -568,20 +605,27 @@ def group_texts(examples):
outputs = model(**batch)
loss = outputs.loss
- losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))
+ losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))
losses = torch.cat(losses)
- losses = losses[: len(eval_dataset)]
try:
- perplexity = math.exp(torch.mean(losses))
+ eval_loss = torch.mean(losses)
+ perplexity = math.exp(eval_loss)
except OverflowError:
perplexity = float("inf")
- logger.info(f"epoch {epoch}: perplexity: {perplexity}")
+ logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}")
if args.with_tracking:
accelerator.log(
- {"perplexity": perplexity, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
+ {
+ "perplexity": perplexity,
+ "eval_loss": eval_loss,
+ "train_loss": total_loss.item() / len(train_dataloader),
+ "epoch": epoch,
+ "step": completed_steps,
+ },
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
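Evaluation now gathers the per-batch losses with `gather_for_metrics` (so the manual slice to `len(eval_dataset)` is no longer needed), averages them into `eval_loss`, and reports perplexity as its exponential, keeping the `OverflowError` guard for diverged runs. Condensed from the hunk above:

```python
# Condensed view of the evaluation-loss / perplexity computation.
losses = []
for batch in eval_dataloader:
    with torch.no_grad():
        outputs = model(**batch)
    # Expand the scalar batch loss to one value per sample so gathering and
    # de-duplication line up with the dataset.
    losses.append(accelerator.gather_for_metrics(outputs.loss.repeat(args.per_device_eval_batch_size)))

losses = torch.cat(losses)
try:
    eval_loss = torch.mean(losses)
    perplexity = math.exp(eval_loss)
except OverflowError:
    perplexity = float("inf")
```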
diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py
index 477ccff95052..b635a7aea698 100755
--- a/examples/pytorch/language-modeling/run_mlm.py
+++ b/examples/pytorch/language-modeling/run_mlm.py
@@ -30,8 +30,9 @@
from typing import Optional
import datasets
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
+import evaluate
import transformers
from transformers import (
CONFIG_MAPPING,
@@ -47,12 +48,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -70,8 +71,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -81,8 +83,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
config_name: Optional[str] = field(
@@ -106,8 +110,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -147,8 +153,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -165,22 +173,28 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -211,6 +225,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_mlm", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -498,7 +516,7 @@ def preprocess_logits_for_metrics(logits, labels):
logits = logits[0]
return logits.argmax(dim=-1)
- metric = load_metric("accuracy")
+ metric = evaluate.load("accuracy")
def compute_metrics(eval_preds):
preds, labels = eval_preds
diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py
index d6a8c1691edb..80dfcf9a9194 100755
--- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py
@@ -45,7 +45,6 @@
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
- AdamW,
AutoConfig,
AutoModelForMaskedLM,
AutoTokenizer,
@@ -53,10 +52,13 @@
SchedulerType,
get_scheduler,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.22.0.dev0")
+
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
@@ -97,7 +99,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
+ required=False,
)
parser.add_argument(
"--config_name",
@@ -171,7 +173,9 @@ def parse_args():
"--max_seq_length",
type=int,
default=None,
- help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated."
+ ),
)
parser.add_argument(
"--line_by_line",
@@ -211,7 +215,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
)
args = parser.parse_args()
@@ -237,9 +251,21 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_mlm_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator_log_kwargs = {}
+
+ if args.with_tracking:
+ accelerator_log_kwargs["log_with"] = args.report_to
+ accelerator_log_kwargs["logging_dir"] = args.output_dir
+
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -490,7 +516,7 @@ def group_texts(examples):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
@@ -500,17 +526,17 @@ def group_texts(examples):
# shorter in multiprocess)
# Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- else:
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
- num_warmup_steps=args.num_warmup_steps,
- num_training_steps=args.max_train_steps,
+ num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
# Prepare everything with our `accelerator`.
@@ -520,7 +546,10 @@ def group_texts(examples):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
@@ -530,12 +559,15 @@ def group_texts(examples):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("mlm_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("mlm_no_trainer", experiment_config)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -584,17 +616,20 @@ def group_texts(examples):
if resume_step is not None and step < resume_step:
completed_steps += 1
continue
- outputs = model(**batch)
- loss = outputs.loss
- # We keep track of the loss at each epoch
- if args.with_tracking:
- total_loss += loss.detach().float()
- loss = loss / args.gradient_accumulation_steps
- accelerator.backward(loss)
- if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+
+ with accelerator.accumulate(model):
+ outputs = model(**batch)
+ loss = outputs.loss
+ # We keep track of the loss at each epoch
+ if args.with_tracking:
+ total_loss += loss.detach().float()
+ accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1
@@ -615,12 +650,12 @@ def group_texts(examples):
outputs = model(**batch)
loss = outputs.loss
- losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))
+ losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))
losses = torch.cat(losses)
- losses = losses[: len(eval_dataset)]
try:
- perplexity = math.exp(torch.mean(losses))
+ eval_loss = torch.mean(losses)
+ perplexity = math.exp(eval_loss)
except OverflowError:
perplexity = float("inf")
@@ -628,7 +663,14 @@ def group_texts(examples):
if args.with_tracking:
accelerator.log(
- {"perplexity": perplexity, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
+ {
+ "perplexity": perplexity,
+ "eval_loss": eval_loss,
+ "train_loss": total_loss.item() / len(train_dataloader),
+ "epoch": epoch,
+ "step": completed_steps,
+ },
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
diff --git a/examples/pytorch/language-modeling/run_plm.py b/examples/pytorch/language-modeling/run_plm.py
index 8974882595ae..4a885ee49661 100755
--- a/examples/pytorch/language-modeling/run_plm.py
+++ b/examples/pytorch/language-modeling/run_plm.py
@@ -42,12 +42,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -63,8 +63,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
config_name: Optional[str] = field(
@@ -73,8 +74,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
tokenizer_name: Optional[str] = field(
@@ -95,8 +98,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -136,8 +141,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=512,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -147,8 +154,10 @@ class DataTrainingArguments:
plm_probability: float = field(
default=1 / 6,
metadata={
- "help": "Ratio of length of a span of masked tokens to surrounding context length for "
- "permutation language modeling."
+ "help": (
+ "Ratio of length of a span of masked tokens to surrounding context length for "
+ "permutation language modeling."
+ )
},
)
max_span_length: int = field(
@@ -161,22 +170,28 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -205,6 +220,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_plm", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/multiple-choice/README.md b/examples/pytorch/multiple-choice/README.md
index 4e3e331e05de..735d1f5f33a0 100644
--- a/examples/pytorch/multiple-choice/README.md
+++ b/examples/pytorch/multiple-choice/README.md
@@ -53,7 +53,7 @@ the mean of the [🤗 `Accelerate`](https://github.com/huggingface/accelerate) l
after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
then
diff --git a/examples/pytorch/multiple-choice/run_swag.py b/examples/pytorch/multiple-choice/run_swag.py
index cd2bdd74ad2b..f9df919e1f92 100755
--- a/examples/pytorch/multiple-choice/run_swag.py
+++ b/examples/pytorch/multiple-choice/run_swag.py
@@ -43,11 +43,11 @@
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import PaddingStrategy, check_min_version
+from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
logger = logging.getLogger(__name__)
@@ -82,8 +82,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -109,30 +111,38 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. If passed, sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. If passed, sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to the maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to the maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -215,6 +225,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_swag", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/multiple-choice/run_swag_no_trainer.py b/examples/pytorch/multiple-choice/run_swag_no_trainer.py
index 756a0287eaa0..eeb04e417fdf 100755
--- a/examples/pytorch/multiple-choice/run_swag_no_trainer.py
+++ b/examples/pytorch/multiple-choice/run_swag_no_trainer.py
@@ -31,10 +31,11 @@
import datasets
import torch
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
+import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -43,7 +44,6 @@
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
- AdamW,
AutoConfig,
AutoModelForMultipleChoice,
AutoTokenizer,
@@ -52,9 +52,12 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import PaddingStrategy, get_full_repo_name
+from transformers.utils import PaddingStrategy, check_min_version, get_full_repo_name, send_example_telemetry
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.22.0.dev0")
+
logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
@@ -62,7 +65,7 @@
def parse_args():
- parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
+ parser = argparse.ArgumentParser(description="Finetune a transformers model on a multiple choice task")
parser.add_argument(
"--dataset_name",
type=str,
@@ -99,7 +102,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
+ required=False,
)
parser.add_argument(
"--config_name",
@@ -194,7 +197,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
)
args = parser.parse_args()
@@ -264,9 +277,21 @@ def __call__(self, features):
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_swag_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator_log_kwargs = {}
+
+ if args.with_tracking:
+ accelerator_log_kwargs["log_with"] = args.report_to
+ accelerator_log_kwargs["logging_dir"] = args.output_dir
+
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -447,24 +472,24 @@ def preprocess_function(examples):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Use the device given by the `accelerator` object.
device = accelerator.device
model.to(device)
# Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- else:
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
- num_warmup_steps=args.num_warmup_steps,
- num_training_steps=args.max_train_steps,
+ num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
# Prepare everything with our `accelerator`.
@@ -474,7 +499,10 @@ def preprocess_function(examples):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
@@ -484,15 +512,18 @@ def preprocess_function(examples):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("swag_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("swag_no_trainer", experiment_config)
# Metrics
- metric = load_metric("accuracy")
+ metric = evaluate.load("accuracy")
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -541,17 +572,20 @@ def preprocess_function(examples):
if resume_step is not None and step < resume_step:
completed_steps += 1
continue
- outputs = model(**batch)
- loss = outputs.loss
- # We keep track of the loss at each epoch
- if args.with_tracking:
- total_loss += loss.detach().float()
- loss = loss / args.gradient_accumulation_steps
- accelerator.backward(loss)
- if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+
+ with accelerator.accumulate(model):
+ outputs = model(**batch)
+ loss = outputs.loss
+ # We keep track of the loss at each epoch
+ if args.with_tracking:
+ total_loss += loss.detach().float()
+ accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1
@@ -566,19 +600,11 @@ def preprocess_function(examples):
break
model.eval()
- samples_seen = 0
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
- predictions, references = accelerator.gather((predictions, batch["labels"]))
- # If we are in a multiprocess environment, the last batch has duplicates
- if accelerator.num_processes > 1:
- if step == len(eval_dataloader):
- predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
- references = references[: len(eval_dataloader.dataset) - samples_seen]
- else:
- samples_seen += references.shape[0]
+ predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
metric.add_batch(
predictions=predictions,
references=references,
@@ -589,7 +615,13 @@ def preprocess_function(examples):
if args.with_tracking:
accelerator.log(
- {"accuracy": eval_metric, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
+ {
+ "accuracy": eval_metric,
+ "train_loss": total_loss.item() / len(train_dataloader),
+ "epoch": epoch,
+ "step": completed_steps,
+ },
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
diff --git a/examples/pytorch/question-answering/README.md b/examples/pytorch/question-answering/README.md
index 480da1d89fdd..f6e660e972d6 100644
--- a/examples/pytorch/question-answering/README.md
+++ b/examples/pytorch/question-answering/README.md
@@ -136,7 +136,7 @@ SQuAD or a similar dataset, the main difference is that this script exposes the
You can use the script normally after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
then
diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py
index 242e83427389..54db2b7bb12d 100755
--- a/examples/pytorch/question-answering/run_qa.py
+++ b/examples/pytorch/question-answering/run_qa.py
@@ -25,8 +25,9 @@
from typing import Optional
import datasets
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
+import evaluate
import transformers
from trainer_qa import QuestionAnsweringTrainer
from transformers import (
@@ -42,13 +43,13 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
@@ -81,8 +82,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -118,37 +121,46 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=384,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=True,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
- "be faster on GPU but will be slower on TPU)."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
+ " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
version_2_with_negative: bool = field(
@@ -157,9 +169,11 @@ class DataTrainingArguments:
null_score_diff_threshold: float = field(
default=0.0,
metadata={
- "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`."
+ "help": (
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ )
},
)
doc_stride: int = field(
@@ -173,8 +187,10 @@ class DataTrainingArguments:
max_answer_length: int = field(
default=30,
metadata={
- "help": "The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another."
+ "help": (
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ )
},
)
@@ -211,6 +227,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_qa", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -319,9 +339,9 @@ def main():
# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
- "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
- "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
- "requirement"
+ "This example script only works for models that have a fast tokenizer. Checkout the big table of models at"
+ " https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet"
+ " this requirement"
)
# Preprocessing the datasets.
@@ -574,7 +594,7 @@ def post_processing_function(examples, features, predictions, stage="eval"):
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
- metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
+ metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)
diff --git a/examples/pytorch/question-answering/run_qa_beam_search.py b/examples/pytorch/question-answering/run_qa_beam_search.py
index d46e96d21043..ce110ae36463 100755
--- a/examples/pytorch/question-answering/run_qa_beam_search.py
+++ b/examples/pytorch/question-answering/run_qa_beam_search.py
@@ -25,8 +25,9 @@
from typing import Optional
import datasets
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
+import evaluate
import transformers
from trainer_qa import QuestionAnsweringTrainer
from transformers import (
@@ -41,13 +42,13 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions_with_beam_search
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
@@ -80,8 +81,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -117,37 +120,46 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=384,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=True,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
- "be faster on GPU but will be slower on TPU)."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
+ " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
version_2_with_negative: bool = field(
@@ -156,9 +168,11 @@ class DataTrainingArguments:
null_score_diff_threshold: float = field(
default=0.0,
metadata={
- "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`."
+ "help": (
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ )
},
)
doc_stride: int = field(
@@ -172,8 +186,10 @@ class DataTrainingArguments:
max_answer_length: int = field(
default=30,
metadata={
- "help": "The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another."
+ "help": (
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ )
},
)
@@ -210,6 +226,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_qa_beam_search", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -606,7 +626,7 @@ def post_processing_function(examples, features, predictions, stage="eval"):
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
- metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
+ metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)
diff --git a/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py b/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
index 1da2f89ed94e..370dd3f43d95 100644
--- a/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
+++ b/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
@@ -29,10 +29,11 @@
import datasets
import numpy as np
import torch
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
+import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -49,13 +50,13 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import check_min_version, get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions_with_beam_search
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
@@ -116,8 +117,10 @@ def parse_args():
"--max_seq_length",
type=int,
default=384,
- help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
- " sequences shorter will be padded if `--pad_to_max_lengh` is passed.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
+ " sequences shorter will be padded if `--pad_to_max_lengh` is passed."
+ ),
)
parser.add_argument(
"--pad_to_max_length",
@@ -190,9 +193,11 @@ def parse_args():
"--null_score_diff_threshold",
type=float,
default=0.0,
- help="The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`.",
+ help=(
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ ),
)
parser.add_argument(
"--version_2_with_negative",
@@ -203,22 +208,28 @@ def parse_args():
"--max_answer_length",
type=int,
default=30,
- help="The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another.",
+ help=(
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ ),
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
- help="For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set.",
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
)
parser.add_argument(
"--max_eval_samples",
type=int,
default=None,
- help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set.",
+ help=(
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ ),
)
parser.add_argument(
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
@@ -281,9 +292,21 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_qa_beam_search_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers
+ # in the environment
+ accelerator_log_kwargs = {}
+
+ if args.with_tracking:
+ accelerator_log_kwargs["log_with"] = args.report_to
+ accelerator_log_kwargs["logging_dir"] = args.output_dir
+
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -666,7 +689,7 @@ def post_processing_function(examples, features, predictions, stage="eval"):
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
- metric = load_metric("squad_v2" if args.version_2_with_negative else "squad")
+ metric = evaluate.load("squad_v2" if args.version_2_with_negative else "squad")
def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
"""
@@ -683,7 +706,7 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
step = 0
# create a numpy array and fill it with -100.
logits_concat = np.full((len(dataset), max_len), -100, dtype=np.float32)
- # Now since we have create an array now we will populate it with the outputs gathered using accelerator.gather
+ # Now that we have created the array, we populate it with the outputs gathered using accelerator.gather_for_metrics
for i, output_logit in enumerate(start_or_end_logits): # populate columns
# We have to fill it such that we have to take the whole tensor and replace it on the newly created array
# And after every iteration we have to change the step
@@ -715,17 +738,17 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- else:
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
- num_warmup_steps=args.num_warmup_steps,
- num_training_steps=args.max_train_steps,
+ num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
# Prepare everything with our `accelerator`.
@@ -735,7 +758,10 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
@@ -800,17 +826,22 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
if resume_step is not None and step < resume_step:
completed_steps += 1
continue
- outputs = model(**batch)
- loss = outputs.loss
- # We keep track of the loss at each epoch
- if args.with_tracking:
- total_loss += loss.detach().float()
- loss = loss / args.gradient_accumulation_steps
- accelerator.backward(loss)
- if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+
+ with accelerator.accumulate(model):
+ outputs = model(**batch)
+ loss = outputs.loss
+ # We keep track of the loss at each epoch
+ if args.with_tracking:
+ total_loss += loss.detach().float()
+
+ accelerator.backward(loss)
+
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1
@@ -858,11 +889,11 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
end_top_index = accelerator.pad_across_processes(end_top_index, dim=1, pad_index=-100)
cls_logits = accelerator.pad_across_processes(cls_logits, dim=1, pad_index=-100)
- all_start_top_log_probs.append(accelerator.gather(start_top_log_probs).cpu().numpy())
- all_start_top_index.append(accelerator.gather(start_top_index).cpu().numpy())
- all_end_top_log_probs.append(accelerator.gather(end_top_log_probs).cpu().numpy())
- all_end_top_index.append(accelerator.gather(end_top_index).cpu().numpy())
- all_cls_logits.append(accelerator.gather(cls_logits).cpu().numpy())
+ all_start_top_log_probs.append(accelerator.gather_for_metrics(start_top_log_probs).cpu().numpy())
+ all_start_top_index.append(accelerator.gather_for_metrics(start_top_index).cpu().numpy())
+ all_end_top_log_probs.append(accelerator.gather_for_metrics(end_top_log_probs).cpu().numpy())
+ all_end_top_index.append(accelerator.gather_for_metrics(end_top_index).cpu().numpy())
+ all_cls_logits.append(accelerator.gather_for_metrics(cls_logits).cpu().numpy())
max_len = max([x.shape[1] for x in all_end_top_log_probs]) # Get the max_length of the tensor
@@ -918,11 +949,11 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
end_top_index = accelerator.pad_across_processes(end_top_index, dim=1, pad_index=-100)
cls_logits = accelerator.pad_across_processes(cls_logits, dim=1, pad_index=-100)
- all_start_top_log_probs.append(accelerator.gather(start_top_log_probs).cpu().numpy())
- all_start_top_index.append(accelerator.gather(start_top_index).cpu().numpy())
- all_end_top_log_probs.append(accelerator.gather(end_top_log_probs).cpu().numpy())
- all_end_top_index.append(accelerator.gather(end_top_index).cpu().numpy())
- all_cls_logits.append(accelerator.gather(cls_logits).cpu().numpy())
+ all_start_top_log_probs.append(accelerator.gather_for_metrics(start_top_log_probs).cpu().numpy())
+ all_start_top_index.append(accelerator.gather_for_metrics(start_top_index).cpu().numpy())
+ all_end_top_log_probs.append(accelerator.gather_for_metrics(end_top_log_probs).cpu().numpy())
+ all_end_top_index.append(accelerator.gather_for_metrics(end_top_index).cpu().numpy())
+ all_cls_logits.append(accelerator.gather_for_metrics(cls_logits).cpu().numpy())
max_len = max([x.shape[1] for x in all_end_top_log_probs]) # Get the max_length of the tensor
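The no_trainer scripts in this diff replace the manual gradient-accumulation bookkeeping with `accelerator.accumulate` and check `accelerator.sync_gradients` to count real optimization steps. A self-contained sketch of that training-loop pattern, with a toy model, optimizer and dataloader standing in for the script's real ones:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=2)
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(48, 4), torch.randint(0, 2, (48,)))
dataloader = DataLoader(dataset, batch_size=8)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

completed_steps = 0
for inputs, labels in dataloader:
    # With accumulation enabled, accelerator.backward scales the loss and the
    # wrapped optimizer skips its step on non-sync iterations.
    with accelerator.accumulate(model):
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
    # True only on the iterations where an optimization step actually happened.
    if accelerator.sync_gradients:
        completed_steps += 1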
diff --git a/examples/pytorch/question-answering/run_qa_no_trainer.py b/examples/pytorch/question-answering/run_qa_no_trainer.py
index c0f47e4526fa..6bf4eb28e994 100755
--- a/examples/pytorch/question-answering/run_qa_no_trainer.py
+++ b/examples/pytorch/question-answering/run_qa_no_trainer.py
@@ -29,10 +29,11 @@
import datasets
import numpy as np
import torch
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
+import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -41,7 +42,6 @@
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
- AdamW,
AutoConfig,
AutoModelForQuestionAnswering,
AutoTokenizer,
@@ -51,13 +51,13 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import check_min_version, get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
@@ -121,8 +121,10 @@ def parse_args():
"--max_seq_length",
type=int,
default=384,
- help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
- " sequences shorter will be padded if `--pad_to_max_lengh` is passed.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
+ " sequences shorter will be padded if `--pad_to_max_lengh` is passed."
+ ),
)
parser.add_argument(
"--pad_to_max_length",
@@ -133,7 +135,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
+ required=False,
)
parser.add_argument(
"--config_name",
@@ -212,9 +214,11 @@ def parse_args():
"--null_score_diff_threshold",
type=float,
default=0.0,
- help="The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`.",
+ help=(
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ ),
)
parser.add_argument(
"--version_2_with_negative",
@@ -225,22 +229,28 @@ def parse_args():
"--max_answer_length",
type=int,
default=30,
- help="The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another.",
+ help=(
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ ),
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
- help="For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set.",
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
)
parser.add_argument(
"--max_eval_samples",
type=int,
default=None,
- help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set.",
+ help=(
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ ),
)
parser.add_argument(
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
@@ -278,7 +288,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
)
args = parser.parse_args()
@@ -310,9 +330,21 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_qa_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator_log_kwargs = {}
+
+ if args.with_tracking:
+ accelerator_log_kwargs["log_with"] = args.report_to
+ accelerator_log_kwargs["logging_dir"] = args.output_dir
+
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -670,7 +702,7 @@ def post_processing_function(examples, features, predictions, stage="eval"):
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
- metric = load_metric("squad_v2" if args.version_2_with_negative else "squad")
+ metric = evaluate.load("squad_v2" if args.version_2_with_negative else "squad")
# Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
@@ -688,7 +720,7 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
step = 0
# create a numpy array and fill it with -100.
logits_concat = np.full((len(dataset), max_len), -100, dtype=np.float64)
- # Now since we have create an array now we will populate it with the outputs gathered using accelerator.gather
+ # Now that we have created the array, we populate it with the outputs gathered using accelerator.gather_for_metrics
for i, output_logit in enumerate(start_or_end_logits): # populate columns
# We have to fill it such that we have to take the whole tensor and replace it on the newly created array
# And after every iteration we have to change the step
@@ -718,20 +750,20 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- else:
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
- num_warmup_steps=args.num_warmup_steps,
- num_training_steps=args.max_train_steps,
+ num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
# Prepare everything with our `accelerator`.
@@ -741,7 +773,10 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
@@ -751,12 +786,15 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("qa_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("qa_no_trainer", experiment_config)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -806,17 +844,21 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
if resume_step is not None and step < resume_step:
completed_steps += 1
continue
- outputs = model(**batch)
- loss = outputs.loss
- # We keep track of the loss at each epoch
- if args.with_tracking:
- total_loss += loss.detach().float()
- loss = loss / args.gradient_accumulation_steps
- accelerator.backward(loss)
- if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+
+ with accelerator.accumulate(model):
+ outputs = model(**batch)
+ loss = outputs.loss
+ # We keep track of the loss at each epoch
+ if args.with_tracking:
+ total_loss += loss.detach().float()
+
+ accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1
@@ -868,8 +910,8 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)
- all_start_logits.append(accelerator.gather(start_logits).cpu().numpy())
- all_end_logits.append(accelerator.gather(end_logits).cpu().numpy())
+ all_start_logits.append(accelerator.gather_for_metrics(start_logits).cpu().numpy())
+ all_end_logits.append(accelerator.gather_for_metrics(end_logits).cpu().numpy())
max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor
@@ -907,8 +949,8 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)
- all_start_logits.append(accelerator.gather(start_logits).cpu().numpy())
- all_end_logits.append(accelerator.gather(end_logits).cpu().numpy())
+ all_start_logits.append(accelerator.gather_for_metrics(start_logits).cpu().numpy())
+ all_end_logits.append(accelerator.gather_for_metrics(end_logits).cpu().numpy())
max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor
# concatenate the numpy array
@@ -927,14 +969,14 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
if args.with_tracking:
log = {
"squad_v2" if args.version_2_with_negative else "squad": eval_metric,
- "train_loss": total_loss,
+ "train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
}
if args.do_predict:
log["squad_v2_predict" if args.version_2_with_negative else "squad_predict"] = predict_metric
- accelerator.log(log)
+ accelerator.log(log, step=completed_steps)
if args.output_dir is not None:
accelerator.wait_for_everyone()
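Several scripts also rework the training-step arithmetic: `max_train_steps` is only derived from the epoch count when the user did not set it, the scheduler arguments are scaled by `gradient_accumulation_steps` to match the `get_scheduler` call above, and the epoch count is recomputed after `accelerator.prepare` may have changed the dataloader length. A standalone sketch of that bookkeeping with made-up values:

import math

num_train_epochs = 3
gradient_accumulation_steps = 4
max_train_steps = None      # i.e. --max_train_steps was not passed
batches_per_epoch = 1000    # len(train_dataloader)
num_warmup_steps = 100

overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
if max_train_steps is None:
    max_train_steps = num_train_epochs * num_update_steps_per_epoch
    overrode_max_train_steps = True

# Mirrors the get_scheduler call above: warmup and total steps carry the accumulation factor.
scheduler_num_warmup_steps = num_warmup_steps * gradient_accumulation_steps
scheduler_num_training_steps = max_train_steps * gradient_accumulation_steps

# After accelerator.prepare the dataloader length may differ, so recompute.
num_update_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
if overrode_max_train_steps:
    max_train_steps = num_train_epochs * num_update_steps_per_epoch
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)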
diff --git a/examples/pytorch/question-answering/run_seq2seq_qa.py b/examples/pytorch/question-answering/run_seq2seq_qa.py
index cb6bd09bc40d..8ffe114dbb86 100644
--- a/examples/pytorch/question-answering/run_seq2seq_qa.py
+++ b/examples/pytorch/question-answering/run_seq2seq_qa.py
@@ -25,8 +25,9 @@
from typing import List, Optional, Tuple
import datasets
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
+import evaluate
import transformers
from trainer_seq2seq_qa import QuestionAnsweringSeq2SeqTrainer
from transformers import (
@@ -39,12 +40,12 @@
set_seed,
)
from transformers.trainer_utils import EvalLoopOutput, EvalPrediction, get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
@@ -81,8 +82,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -130,53 +133,66 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=384,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_answer_length: int = field(
default=30,
metadata={
- "help": "The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another."
+ "help": (
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ )
},
)
val_max_answer_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_answer_length`."
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_answer_length`."
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
pad_to_max_length: bool = field(
default=True,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
- "be faster on GPU but will be slower on TPU)."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
+ " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
version_2_with_negative: bool = field(
@@ -185,9 +201,11 @@ class DataTrainingArguments:
null_score_diff_threshold: float = field(
default=0.0,
metadata={
- "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`."
+ "help": (
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ )
},
)
doc_stride: int = field(
@@ -201,8 +219,10 @@ class DataTrainingArguments:
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
- "which is used during ``evaluate`` and ``predict``."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``."
+ )
},
)
ignore_pad_token_for_loss: bool = field(
@@ -252,6 +272,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_seq2seq_qa", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -430,9 +454,8 @@ def preprocess_function(examples):
inputs, targets = preprocess_squad_batch(examples, question_column, context_column, answer_column)
model_inputs = tokenizer(inputs, max_length=max_seq_length, padding=padding, truncation=True)
- # Setup the tokenizer for targets
- with tokenizer.as_target_tokenizer():
- labels = tokenizer(targets, max_length=max_answer_length, padding=padding, truncation=True)
+ # Tokenize targets with text_target=...
+ labels = tokenizer(text_target=targets, max_length=max_answer_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
@@ -456,9 +479,8 @@ def preprocess_validation_function(examples):
return_overflowing_tokens=True,
return_offsets_mapping=True,
)
- # Setup the tokenizer for targets
- with tokenizer.as_target_tokenizer():
- labels = tokenizer(targets, max_length=max_answer_length, padding=padding, truncation=True)
+ # Tokenize targets with the `text_target` keyword argument
+ labels = tokenizer(text_target=targets, max_length=max_answer_length, padding=padding, truncation=True)
# Since one example might give us several features if it has a long context, we need a map from a feature to
# its corresponding example. This key gives us just that.
@@ -560,7 +582,7 @@ def preprocess_validation_function(examples):
pad_to_multiple_of=8 if training_args.fp16 else None,
)
- metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
+ metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)
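run_seq2seq_qa.py also changes how target answers are tokenized: the `as_target_tokenizer()` context manager is replaced by the `text_target` argument, which requires a recent transformers release, consistent with the check_min_version bump in this diff. A minimal sketch assuming an illustrative t5-small checkpoint and invented question/answer text:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
inputs = ["question: Who wrote Hamlet? context: Hamlet is a tragedy by William Shakespeare."]
targets = ["William Shakespeare"]

model_inputs = tokenizer(inputs, max_length=384, padding="max_length", truncation=True)
# Targets are tokenized in the same call style; no context manager needed.
labels = tokenizer(text_target=targets, max_length=30, padding="max_length", truncation=True)
model_inputs["labels"] = labels["input_ids"]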
diff --git a/examples/pytorch/question-answering/trainer_qa.py b/examples/pytorch/question-answering/trainer_qa.py
index 7f98eba236c1..59d7a084c108 100644
--- a/examples/pytorch/question-answering/trainer_qa.py
+++ b/examples/pytorch/question-answering/trainer_qa.py
@@ -20,7 +20,7 @@
from transformers.trainer_utils import PredictionOutput
-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
diff --git a/examples/pytorch/question-answering/trainer_seq2seq_qa.py b/examples/pytorch/question-answering/trainer_seq2seq_qa.py
index ac260dadbc33..6ad66aeec5b4 100644
--- a/examples/pytorch/question-answering/trainer_seq2seq_qa.py
+++ b/examples/pytorch/question-answering/trainer_seq2seq_qa.py
@@ -23,7 +23,7 @@
from transformers.trainer_utils import PredictionOutput
-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
@@ -41,11 +41,16 @@ def evaluate(
eval_examples=None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
- max_length: Optional[int] = None,
- num_beams: Optional[int] = None,
+ **gen_kwargs,
) -> Dict[str, float]:
- self._max_length = max_length if max_length is not None else self.args.generation_max_length
- self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
+ gen_kwargs = gen_kwargs.copy()
+ gen_kwargs["max_length"] = (
+ gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
+ )
+ gen_kwargs["num_beams"] = (
+ gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
+ )
+ self._gen_kwargs = gen_kwargs
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
eval_dataloader = self.get_eval_dataloader(eval_dataset)
@@ -87,7 +92,11 @@ def evaluate(
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
return metrics
- def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
+ def predict(
+ self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test", **gen_kwargs
+ ):
+ self._gen_kwargs = gen_kwargs.copy()
+
predict_dataloader = self.get_test_dataloader(predict_dataset)
# Temporarily disable metric computation, we will do it in the loop here.
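The seq2seq QA trainer now accepts arbitrary generation keyword arguments instead of only `max_length`/`num_beams`. A standalone sketch of the fallback logic those lines implement, with `args` standing in for the trainer's `self.args`:

from argparse import Namespace

args = Namespace(generation_max_length=30, generation_num_beams=1)

def resolve_gen_kwargs(args, **gen_kwargs):
    # Explicit keyword arguments win; otherwise fall back to the generation_* training args.
    gen_kwargs = gen_kwargs.copy()
    if gen_kwargs.get("max_length") is None:
        gen_kwargs["max_length"] = args.generation_max_length
    if gen_kwargs.get("num_beams") is None:
        gen_kwargs["num_beams"] = args.generation_num_beams
    return gen_kwargs

print(resolve_gen_kwargs(args, num_beams=4))  # {'num_beams': 4, 'max_length': 30}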
diff --git a/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py b/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py
index 304f8848b49b..bc1bfb2c1c09 100644
--- a/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py
+++ b/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py
@@ -21,7 +21,6 @@
from dataclasses import dataclass, field
from typing import Optional
-import datasets
import numpy as np
import torch
from datasets import load_dataset
@@ -30,6 +29,7 @@
from torchvision import transforms
from torchvision.transforms import functional
+import evaluate
import transformers
from huggingface_hub import hf_hub_download
from transformers import (
@@ -42,7 +42,7 @@
default_data_collator,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
@@ -51,7 +51,7 @@
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
@@ -194,15 +194,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
reduce_labels: Optional[bool] = field(
@@ -241,8 +245,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -260,6 +266,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_semantic_segmentation", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -327,7 +337,7 @@ def main():
label2id = {v: str(k) for k, v in id2label.items()}
# Load the mean IoU metric from the datasets package
- metric = datasets.load_metric("mean_iou")
+ metric = evaluate.load("mean_iou")
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
diff --git a/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py b/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
index 979d5e5ca4b1..30cb7cc53ae3 100644
--- a/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
+++ b/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
@@ -24,13 +24,14 @@
import datasets
import numpy as np
import torch
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import functional
from tqdm.auto import tqdm
+import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -44,10 +45,13 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.22.0.dev0")
+
logger = get_logger(__name__)
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
@@ -285,7 +289,17 @@ def parse_args():
"--with_tracking",
required=False,
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
)
args = parser.parse_args()
@@ -305,9 +319,21 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_semantic_segmentation_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator_log_kwargs = {}
+
+ if args.with_tracking:
+ accelerator_log_kwargs["log_with"] = args.report_to
+ accelerator_log_kwargs["logging_dir"] = args.output_dir
+
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
+
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
@@ -457,17 +483,17 @@ def preprocess_val(example_batch):
checkpointing_steps = None
# Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- else:
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
- num_warmup_steps=args.num_warmup_steps,
- num_training_steps=args.max_train_steps,
+ num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
# Prepare everything with our `accelerator`.
@@ -477,16 +503,23 @@ def preprocess_val(example_batch):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Instantiate metric
- metric = load_metric("mean_iou")
+ metric = evaluate.load("mean_iou")
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("semantic_segmentation_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("semantic_segmentation_no_trainer", experiment_config)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -535,17 +568,20 @@ def preprocess_val(example_batch):
if resume_step is not None and step < resume_step:
completed_steps += 1
continue
- outputs = model(**batch)
- loss = outputs.loss
- # We keep track of the loss at each epoch
- if args.with_tracking:
- total_loss += loss.detach().float()
- loss = loss / args.gradient_accumulation_steps
- accelerator.backward(loss)
- if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+
+ with accelerator.accumulate(model):
+ outputs = model(**batch)
+ loss = outputs.loss
+ # We keep track of the loss at each epoch
+ if args.with_tracking:
+ total_loss += loss.detach().float()
+ accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1
@@ -577,7 +613,6 @@ def preprocess_val(example_batch):
logger.info("***** Running evaluation *****")
model.eval()
- samples_seen = 0
for step, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process)):
with torch.no_grad():
outputs = model(**batch)
@@ -587,15 +622,7 @@ def preprocess_val(example_batch):
)
predictions = upsampled_logits.argmax(dim=1)
- predictions, references = accelerator.gather((predictions, batch["labels"]))
-
- # If we are in a multiprocess environment, the last batch has duplicates
- if accelerator.num_processes > 1:
- if step == len(eval_dataloader):
- predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
- references = references[: len(eval_dataloader.dataset) - samples_seen]
- else:
- samples_seen += references.shape[0]
+ predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
metric.add_batch(
predictions=predictions,
@@ -615,10 +642,11 @@ def preprocess_val(example_batch):
"mean_iou": eval_metrics["mean_iou"],
"mean_accuracy": eval_metrics["mean_accuracy"],
"overall_accuracy": eval_metrics["overall_accuracy"],
- "train_loss": total_loss,
+ "train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
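The evaluation loop above can drop the manual `samples_seen` trimming because `gather_for_metrics`, unlike plain `gather`, removes the samples that the distributed sampler duplicates in the final batch of a prepared dataloader. A toy single-file sketch of the pattern (the tensors stand in for the script's predictions and labels; under `accelerate launch` with several processes the concatenated output would still contain each sample exactly once):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
dataloader = accelerator.prepare(DataLoader(TensorDataset(torch.arange(10)), batch_size=4))

gathered = []
for (batch,) in dataloader:
    predictions = batch * 2  # pretend model output
    predictions, references = accelerator.gather_for_metrics((predictions, batch))
    gathered.append(references)
print(torch.cat(gathered).tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]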
diff --git a/examples/pytorch/speech-pretraining/README.md b/examples/pytorch/speech-pretraining/README.md
index fc8e16623aa3..1d57fc8e72df 100644
--- a/examples/pytorch/speech-pretraining/README.md
+++ b/examples/pytorch/speech-pretraining/README.md
@@ -43,7 +43,7 @@ A good metric to observe during training is the gradient norm which should ideal
When training a model on large datasets it is recommended to run the data preprocessing
in a first run in a **non-distributed** mode via `--preprocessing_only` so that
-when running the model in **distributed** mode in a second step the preprocessed data
+when running the model in **distributed** mode in a second step the preprocessed data
can easily be loaded on each distributed device.
---
diff --git a/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py b/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py
index a66d1f54939c..a3db215d08bd 100755
--- a/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py
+++ b/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py
@@ -43,7 +43,7 @@
set_seed,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
logger = get_logger(__name__)
@@ -219,7 +219,10 @@ def parse_args():
"--pad_to_multiple_of",
type=int,
default=None,
- help="If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).",
+ help=(
+ "If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the"
+ " use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta)."
+ ),
)
parser.add_argument(
"--adam_beta1",
@@ -360,6 +363,10 @@ def main():
# We now keep distinct sets of args, for a cleaner separation of concerns.
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_wav2vec2_pretraining_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator()
logger.info(accelerator.state, main_process_only=False)
@@ -440,7 +447,7 @@ def main():
# only normalized-inputs-training is supported
if not feature_extractor.do_normalize:
raise ValueError(
- "Training is only supported for normalized inputs. " "Make sure ``feature_extractor.do_normalize == True``"
+ "Training is only supported for normalized inputs. Make sure ``feature_extractor.do_normalize == True``"
)
# set max & min audio length in number of samples
@@ -496,7 +503,8 @@ def prepare_dataset(batch):
# apply_spec_augment has to be True, mask_feature_prob has to be 0.0
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
raise ValueError(
- "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
+ "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and"
+ " ``config.feat_extract_norm='layer'"
)
# initialize random model
@@ -538,8 +546,6 @@ def prepare_dataset(batch):
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- else:
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
@@ -548,6 +554,9 @@ def prepare_dataset(batch):
num_training_steps=args.max_train_steps,
)
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
# 5. Train
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -615,7 +624,7 @@ def prepare_dataset(batch):
lr_scheduler.step()
elif accelerator.is_local_main_process:
progress_bar.write(
- "Gradients have overflown - skipping update step... " f"Updating gradient scale to {scale}..."
+ f"Gradients have overflown - skipping update step... Updating gradient scale to {scale}..."
)
# update gumbel temperature
diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
index 6df37086240d..36efb44138d9 100755
--- a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
+++ b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
@@ -28,8 +28,9 @@
import datasets
import numpy as np
import torch
-from datasets import DatasetDict, load_dataset, load_metric
+from datasets import DatasetDict, load_dataset
+import evaluate
import transformers
from transformers import (
AutoConfig,
@@ -44,12 +45,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
@@ -101,9 +102,11 @@ class ModelArguments:
mask_time_prob: float = field(
default=0.05,
metadata={
- "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
- "vectors will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the time axis to be chosen as the start of the vector"
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
+ "vectors will be masked along the time axis."
+ )
},
)
mask_time_length: int = field(
@@ -113,8 +116,11 @@ class ModelArguments:
mask_feature_prob: float = field(
default=0.0,
metadata={
- "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the feature axis to be chosen as the start of the vectorspan"
+ " to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature"
+ " bins will be masked along the time axis."
+ )
},
)
mask_feature_length: int = field(
@@ -146,8 +152,10 @@ class DataTrainingArguments:
train_split_name: str = field(
default="train+validation",
metadata={
- "help": "The name of the training data set split to use (via the datasets library). Defaults to "
- "'train+validation'"
+ "help": (
+ "The name of the training data set split to use (via the datasets library). Defaults to "
+ "'train+validation'"
+ )
},
)
eval_split_name: str = field(
@@ -174,15 +182,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ )
},
)
chars_to_ignore: Optional[List[str]] = list_field(
@@ -196,7 +208,10 @@ class DataTrainingArguments:
max_duration_in_seconds: float = field(
default=20.0,
metadata={
- "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
+ "help": (
+ "Filter audio files that are longer than `max_duration_in_seconds` seconds to"
+ " 'max_duration_in_seconds`"
+ )
},
)
min_duration_in_seconds: float = field(
@@ -205,17 +220,21 @@ class DataTrainingArguments:
preprocessing_only: bool = field(
default=False,
metadata={
- "help": "Whether to only do data preprocessing and skip training. "
- "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
- "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
- "so that the cached datasets can consequently be loaded in distributed training"
+ "help": (
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
+ " can consequently be loaded in distributed training"
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "If :obj:`True`, will use the token generated when running"
- ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
+ "help": (
+ "If :obj:`True`, will use the token generated when running"
+ ":obj:`huggingface-cli login` as HTTP bearer authorization for remote files."
+ )
},
)
unk_token: str = field(
@@ -233,10 +252,12 @@ class DataTrainingArguments:
phoneme_language: Optional[str] = field(
default=None,
metadata={
- "help": "The target language that should be used be"
- " passed to the tokenizer for tokenization. Note that"
- " this is only relevant if the model classifies the"
- " input audio to a sequence of phoneme sequences."
+ "help": (
+ "The target language that should be used be"
+ " passed to the tokenizer for tokenization. Note that"
+ " this is only relevant if the model classifies the"
+ " input audio to a sequence of phoneme sequences."
+ )
},
)
@@ -285,13 +306,12 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
return_tensors="pt",
)
- with self.processor.as_target_processor():
- labels_batch = self.processor.pad(
- label_features,
- padding=self.padding,
- pad_to_multiple_of=self.pad_to_multiple_of_labels,
- return_tensors="pt",
- )
+ labels_batch = self.processor.pad(
+ labels=label_features,
+ padding=self.padding,
+ pad_to_multiple_of=self.pad_to_multiple_of_labels,
+ return_tensors="pt",
+ )
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
@@ -356,6 +376,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_speech_recognition_ctc", model_args, data_args)
+
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
@@ -405,9 +429,9 @@ def main():
if data_args.audio_column_name not in raw_datasets["train"].column_names:
raise ValueError(
- f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
- "Make sure to set `--audio_column_name` to the correct audio column - one of "
- f"{', '.join(raw_datasets['train'].column_names)}."
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
+ " Make sure to set `--audio_column_name` to the correct audio column - one of"
+ f" {', '.join(raw_datasets['train'].column_names)}."
)
if data_args.text_column_name not in raw_datasets["train"].column_names:
@@ -481,7 +505,12 @@ def remove_special_characters(batch):
with training_args.main_process_first():
if training_args.overwrite_output_dir and os.path.isfile(vocab_file):
- os.remove(vocab_file)
+ try:
+ os.remove(vocab_file)
+ except OSError:
+ # in shared file-systems it might be the case that
+ # two processes try to delete the vocab file at the same time
+ pass
with training_args.main_process_first(desc="dataset map vocabulary creation"):
if not os.path.isfile(vocab_file):
@@ -615,7 +644,7 @@ def is_audio_in_length_range(length):
# instantiate a data collator and the trainer
# Define evaluation metrics during training, *i.e.* word error rate, character error rate
- eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}
+ eval_metrics = {metric: evaluate.load(metric) for metric in data_args.eval_metrics}
# for large datasets it is advised to run the preprocessing on a
# single machine first with ``args.preprocessing_only`` since there will mostly likely
@@ -720,7 +749,10 @@ def compute_metrics(pred):
"finetuned_from": model_args.model_name_or_path,
"tasks": "speech-recognition",
"tags": ["automatic-speech-recognition", data_args.dataset_name],
- "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
+ "dataset_args": (
+ f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
+ f" {data_args.eval_split_name}"
+ ),
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
}
if "common_voice" in data_args.dataset_name:
diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py
index 3c368c4ae836..015c1f0a6532 100755
--- a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py
+++ b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py
@@ -27,8 +27,9 @@
import datasets
import torch
-from datasets import DatasetDict, load_dataset, load_metric
+from datasets import DatasetDict, load_dataset
+import evaluate
import transformers
from transformers import (
AutoConfig,
@@ -42,12 +43,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
@@ -87,8 +88,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
freeze_feature_encoder: bool = field(
@@ -122,15 +125,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
audio_column_name: str = field(
@@ -144,7 +151,10 @@ class DataTrainingArguments:
max_duration_in_seconds: float = field(
default=20.0,
metadata={
- "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
+ "help": (
+ "Truncate audio files that are longer than `max_duration_in_seconds` seconds to"
+ " 'max_duration_in_seconds`"
+ )
},
)
min_duration_in_seconds: float = field(
@@ -153,10 +163,12 @@ class DataTrainingArguments:
preprocessing_only: bool = field(
default=False,
metadata={
- "help": "Whether to only do data preprocessing and skip training. "
- "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
- "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
- "so that the cached datasets can consequently be loaded in distributed training"
+ "help": (
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
+ " can consequently be loaded in distributed training"
+ )
},
)
train_split_name: str = field(
@@ -228,6 +240,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_speech_recognition_seq2seq", model_args, data_args)
+
# 2. Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -410,7 +426,7 @@ def is_audio_in_length_range(length):
return
# 8. Load Metric
- metric = load_metric("wer")
+ metric = evaluate.load("wer")
def compute_metrics(pred):
pred_ids = pred.predictions
diff --git a/examples/pytorch/summarization/README.md b/examples/pytorch/summarization/README.md
index bf42e796434e..db7f8f4061a5 100644
--- a/examples/pytorch/summarization/README.md
+++ b/examples/pytorch/summarization/README.md
@@ -149,7 +149,7 @@ the mean of the [🤗 `Accelerate`](https://github.com/huggingface/accelerate) l
after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
then
diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py
index c35b636d7dd9..5d6d5d5c771b 100755
--- a/examples/pytorch/summarization/run_summarization.py
+++ b/examples/pytorch/summarization/run_summarization.py
@@ -27,8 +27,9 @@
import datasets
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
+import evaluate
import transformers
from filelock import FileLock
from transformers import (
@@ -46,12 +47,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version, is_offline_mode
+from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
@@ -101,15 +102,19 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
resize_position_embeddings: Optional[bool] = field(
default=None,
metadata={
- "help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
- "the model's position embeddings."
+ "help": (
+ "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
+ "the model's position embeddings."
+ )
},
)
@@ -120,7 +125,7 @@ class DataTrainingArguments:
Arguments pertaining to what data we are going to input our model for training and eval.
"""
- lang: str = field(default=None, metadata={"help": "Language id for summarization."})
+ lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
@@ -142,14 +147,15 @@ class DataTrainingArguments:
validation_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
- "(a jsonlines or csv file)."
+ "help": (
+ "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
+ )
},
)
test_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
+ "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
@@ -162,60 +168,76 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
- "which is used during ``evaluate`` and ``predict``."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``."
+ )
},
)
ignore_pad_token_for_loss: bool = field(
@@ -231,9 +253,11 @@ class DataTrainingArguments:
forced_bos_token: Optional[str] = field(
default=None,
metadata={
- "help": "The token to force as the first generated token after the decoder_start_token_id."
- "Useful for multilingual models like mBART where the first generated token"
- "needs to be the target language token (Usually it is the target language token)"
+ "help": (
+ "The token to force as the first generated token after the decoder_start_token_id."
+ "Useful for multilingual models like mBART where the first generated token"
+ "needs to be the target language token (Usually it is the target language token)"
+ )
},
)
@@ -263,6 +287,7 @@ def __post_init__(self):
"xglue": ("news_body", "news_title"),
"xsum": ("document", "summary"),
"wiki_summary": ("article", "highlights"),
+ "multi_news": ("document", "summary"),
}
@@ -279,6 +304,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_summarization", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -410,17 +439,18 @@ def main():
):
if model_args.resize_position_embeddings is None:
logger.warning(
- f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} "
- f"to {data_args.max_source_length}."
+ "Increasing the model's number of position embedding vectors from"
+ f" {model.config.max_position_embeddings} to {data_args.max_source_length}."
)
model.resize_position_embeddings(data_args.max_source_length)
elif model_args.resize_position_embeddings:
model.resize_position_embeddings(data_args.max_source_length)
else:
raise ValueError(
- f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}"
- f" position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically "
- "resize the model's position encodings by passing `--resize_position_embeddings`."
+ f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has"
+ f" {model.config.max_position_embeddings} position encodings. Consider either reducing"
+ f" `--max_source_length` to {model.config.max_position_embeddings} or to automatically resize the"
+ " model's position encodings by passing `--resize_position_embeddings`."
)
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
@@ -486,16 +516,15 @@ def preprocess_function(examples):
inputs, targets = [], []
for i in range(len(examples[text_column])):
- if examples[text_column][i] is not None and examples[summary_column][i] is not None:
+ if examples[text_column][i] and examples[summary_column][i]:
inputs.append(examples[text_column][i])
targets.append(examples[summary_column][i])
inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
- # Setup the tokenizer for targets
- with tokenizer.as_target_tokenizer():
- labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
+ # Tokenize targets with the `text_target` keyword argument
+ labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
@@ -570,7 +599,7 @@ def preprocess_function(examples):
)
# Metric
- metric = load_metric("rouge")
+ metric = evaluate.load("rouge")
def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
@@ -596,12 +625,9 @@ def compute_metrics(eval_preds):
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
- # Extract a few results from ROUGE
- result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
-
+ result = {k: round(v * 100, 4) for k, v in result.items()}
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
result["gen_len"] = np.mean(prediction_lens)
- result = {k: round(v, 4) for k, v in result.items()}
return result
# Initialize our Trainer
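
The target-tokenization change above (replacing the `as_target_tokenizer` context manager with the `text_target` argument) generalizes to any seq2seq preprocessing. A minimal sketch, where `t5-small` and the example sentences are placeholders:

```python
# Hedged sketch of the `text_target` tokenization pattern adopted above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")

inputs = ["summarize: The quick brown fox jumps over the lazy dog."]
targets = ["A fox jumps over a dog."]

model_inputs = tokenizer(inputs, max_length=1024, truncation=True)
# Labels are tokenized in the same call; no `with tokenizer.as_target_tokenizer():` needed.
labels = tokenizer(text_target=targets, max_length=128, truncation=True)
model_inputs["labels"] = labels["input_ids"]
print(model_inputs["labels"])
```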
diff --git a/examples/pytorch/summarization/run_summarization_no_trainer.py b/examples/pytorch/summarization/run_summarization_no_trainer.py
index 59ec178c974d..96781b6dcadb 100644
--- a/examples/pytorch/summarization/run_summarization_no_trainer.py
+++ b/examples/pytorch/summarization/run_summarization_no_trainer.py
@@ -30,10 +30,11 @@
import nltk
import numpy as np
import torch
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
+import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -43,7 +44,6 @@
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
- AdamW,
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
@@ -51,10 +51,13 @@
SchedulerType,
get_scheduler,
)
-from transformers.utils import get_full_repo_name, is_offline_mode
+from transformers.utils import check_min_version, get_full_repo_name, is_offline_mode, send_example_telemetry
from transformers.utils.versions import require_version
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.22.0.dev0")
+
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
@@ -111,20 +114,22 @@ def parse_args():
"--ignore_pad_token_for_loss",
type=bool,
default=True,
- help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.",
+ help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.",
)
parser.add_argument(
"--max_source_length",
type=int,
default=1024,
- help="The maximum total input sequence length after "
- "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after "
+ "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--source_prefix",
type=str,
default=None,
- help="A prefix to add before every source text " "(useful for T5 models).",
+ help="A prefix to add before every source text (useful for T5 models).",
)
parser.add_argument(
"--preprocessing_num_workers",
@@ -139,18 +144,22 @@ def parse_args():
"--max_target_length",
type=int,
default=128,
- help="The maximum total sequence length for target text after "
- "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
- "during ``evaluate`` and ``predict``.",
+ help=(
+ "The maximum total sequence length for target text after "
+ "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
+ "during ``evaluate`` and ``predict``."
+ ),
)
parser.add_argument(
"--val_max_target_length",
type=int,
default=None,
- help="The maximum total sequence length for validation "
- "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
- "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
- "param of ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+ help=(
+ "The maximum total sequence length for validation "
+ "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
+ "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
+ "param of ``model.generate``, which is used during ``evaluate`` and ``predict``."
+ ),
)
parser.add_argument(
"--max_length",
@@ -165,8 +174,10 @@ def parse_args():
"--num_beams",
type=int,
default=None,
- help="Number of beams to use for evaluation. This argument will be "
- "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+ help=(
+ "Number of beams to use for evaluation. This argument will be "
+ "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``."
+ ),
)
parser.add_argument(
"--pad_to_max_length",
@@ -177,7 +188,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
+ required=False,
)
parser.add_argument(
"--config_name",
@@ -279,7 +290,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
)
args = parser.parse_args()
@@ -302,7 +323,20 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_summarization_no_trainer", args)
+
+ # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator_log_kwargs = {}
+ if args.with_tracking:
+ accelerator_log_kwargs["log_with"] = args.report_to
+ accelerator_log_kwargs["logging_dir"] = args.output_dir
+
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
if args.source_prefix is None and args.model_name_or_path in [
"t5-small",
"t5-base",
@@ -314,9 +348,6 @@ def main():
"You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
"`--source_prefix 'summarize: ' `"
)
- # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -447,9 +478,8 @@ def preprocess_function(examples):
inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
- # Setup the tokenizer for targets
- with tokenizer.as_target_tokenizer():
- labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
+ # Tokenize targets with the `text_target` keyword argument
+ labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
@@ -503,7 +533,7 @@ def postprocess_text(preds, labels):
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
- no_decay = ["bias", "LayerNorm.weight"]
+ no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
@@ -514,20 +544,20 @@ def postprocess_text(preds, labels):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- else:
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
- num_warmup_steps=args.num_warmup_steps,
- num_training_steps=args.max_train_steps,
+ num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
# Prepare everything with our `accelerator`.
@@ -537,7 +567,10 @@ def postprocess_text(preds, labels):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
@@ -547,15 +580,18 @@ def postprocess_text(preds, labels):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("summarization_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("summarization_no_trainer", experiment_config)
# Metric
- metric = load_metric("rouge")
+ metric = evaluate.load("rouge")
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -603,17 +639,20 @@ def postprocess_text(preds, labels):
if resume_step is not None and step < resume_step:
completed_steps += 1
continue
- outputs = model(**batch)
- loss = outputs.loss
- # We keep track of the loss at each epoch
- if args.with_tracking:
- total_loss += loss.detach().float()
- loss = loss / args.gradient_accumulation_steps
- accelerator.backward(loss)
- if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+
+ with accelerator.accumulate(model):
+ outputs = model(**batch)
+ loss = outputs.loss
+ # We keep track of the loss at each epoch
+ if args.with_tracking:
+ total_loss += loss.detach().float()
+ accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1
@@ -667,29 +706,26 @@ def postprocess_text(preds, labels):
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
- if step == len(eval_dataloader):
+ if step == len(eval_dataloader) - 1:
decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
else:
- samples_seen += decoded_labels.shape[0]
+ samples_seen += len(decoded_labels)
metric.add_batch(
predictions=decoded_preds,
references=decoded_labels,
)
result = metric.compute(use_stemmer=True)
- # Extract a few results from ROUGE
- result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
-
- result = {k: round(v, 4) for k, v in result.items()}
+ result = {k: round(v * 100, 4) for k, v in result.items()}
logger.info(result)
if args.with_tracking:
- result["train_loss"] = total_loss
+ result["train_loss"] = total_loss.item() / len(train_dataloader)
result["epoch"] = epoch
result["step"] = completed_steps
- accelerator.log(result)
+ accelerator.log(result, step=completed_steps)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone()
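
The training-loop rewrite above delegates gradient accumulation to Accelerate via `accelerator.accumulate(model)` and counts optimizer updates with `accelerator.sync_gradients`. A toy, self-contained sketch of that pattern (the model and data are dummies purely for illustration):

```python
# Toy reproduction of the accelerator.accumulate(...) loop used above.
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator(gradient_accumulation_steps=4)

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
dataloader = DataLoader(dataset, batch_size=8)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

completed_steps = 0
for inputs, labels in dataloader:
    # Accelerate decides whether this step accumulates or actually syncs gradients.
    with accelerator.accumulate(model):
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
    # True only on steps where an optimizer update really happened.
    if accelerator.sync_gradients:
        completed_steps += 1
print(completed_steps)  # 2 with 8 batches and 4 accumulation steps
```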
diff --git a/examples/pytorch/test_accelerate_examples.py b/examples/pytorch/test_accelerate_examples.py
index 14eef9c7f772..99a8b0db84a0 100644
--- a/examples/pytorch/test_accelerate_examples.py
+++ b/examples/pytorch/test_accelerate_examples.py
@@ -18,49 +18,18 @@
import json
import logging
import os
+import shutil
import sys
-from unittest.mock import patch
+import tempfile
+from unittest import mock
import torch
-from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
+from accelerate.utils import write_basic_config
+from transformers.testing_utils import TestCasePlus, get_gpu_count, run_command, slow, torch_device
from transformers.utils import is_apex_available
-SRC_DIRS = [
- os.path.join(os.path.dirname(__file__), dirname)
- for dirname in [
- "text-generation",
- "text-classification",
- "token-classification",
- "language-modeling",
- "multiple-choice",
- "question-answering",
- "summarization",
- "translation",
- "image-classification",
- "speech-recognition",
- "audio-classification",
- "speech-pretraining",
- "image-pretraining",
- "semantic-segmentation",
- ]
-]
-sys.path.extend(SRC_DIRS)
-
-
-if SRC_DIRS is not None:
- import run_clm_no_trainer
- import run_glue_no_trainer
- import run_image_classification_no_trainer
- import run_mlm_no_trainer
- import run_ner_no_trainer
- import run_qa_no_trainer as run_squad_no_trainer
- import run_semantic_segmentation_no_trainer
- import run_summarization_no_trainer
- import run_swag_no_trainer
- import run_translation_no_trainer
-
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
@@ -94,10 +63,23 @@ def is_cuda_and_apex_available():
class ExamplesTestsNoTrainer(TestCasePlus):
+ @classmethod
+ def setUpClass(cls):
+ # Write Accelerate config, will pick up on CPU, GPU, and multi-GPU
+ cls.tmpdir = tempfile.mkdtemp()
+ cls.configPath = os.path.join(cls.tmpdir, "default_config.yml")
+ write_basic_config(save_location=cls.configPath)
+ cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tmpdir)
+
+ @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
def test_run_glue_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_glue_no_trainer.py
+ {self.examples_dir}/pytorch/text-classification/run_glue_no_trainer.py
--model_name_or_path distilbert-base-uncased
--output_dir {tmp_dir}
--train_file ./tests/fixtures/tests_samples/MRPC/train.csv
@@ -113,17 +95,17 @@ def test_run_glue_no_trainer(self):
if is_cuda_and_apex_available():
testargs.append("--fp16")
- with patch.object(sys, "argv", testargs):
- run_glue_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertGreaterEqual(result["eval_accuracy"], 0.75)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "glue_no_trainer")))
+ run_command(self._launch_args + testargs)
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["eval_accuracy"], 0.75)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "glue_no_trainer")))
+ @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
def test_run_clm_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_clm_no_trainer.py
+ {self.examples_dir}/pytorch/language-modeling/run_clm_no_trainer.py
--model_name_or_path distilgpt2
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
@@ -140,17 +122,17 @@ def test_run_clm_no_trainer(self):
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
return
- with patch.object(sys, "argv", testargs):
- run_clm_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertLess(result["perplexity"], 100)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "clm_no_trainer")))
+ run_command(self._launch_args + testargs)
+ result = get_results(tmp_dir)
+ self.assertLess(result["perplexity"], 100)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "clm_no_trainer")))
+ @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
def test_run_mlm_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_mlm_no_trainer.py
+ {self.examples_dir}/pytorch/language-modeling/run_mlm_no_trainer.py
--model_name_or_path distilroberta-base
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
@@ -160,20 +142,20 @@ def test_run_mlm_no_trainer(self):
--with_tracking
""".split()
- with patch.object(sys, "argv", testargs):
- run_mlm_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertLess(result["perplexity"], 42)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "mlm_no_trainer")))
+ run_command(self._launch_args + testargs)
+ result = get_results(tmp_dir)
+ self.assertLess(result["perplexity"], 42)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "mlm_no_trainer")))
+ @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
def test_run_ner_no_trainer(self):
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
epochs = 7 if get_gpu_count() > 1 else 2
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_ner_no_trainer.py
+ {self.examples_dir}/pytorch/token-classification/run_ner_no_trainer.py
--model_name_or_path bert-base-uncased
--train_file tests/fixtures/tests_samples/conll/sample.json
--validation_file tests/fixtures/tests_samples/conll/sample.json
@@ -187,18 +169,18 @@ def test_run_ner_no_trainer(self):
--with_tracking
""".split()
- with patch.object(sys, "argv", testargs):
- run_ner_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertGreaterEqual(result["eval_accuracy"], 0.75)
- self.assertLess(result["train_loss"], 0.5)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "ner_no_trainer")))
+ run_command(self._launch_args + testargs)
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["eval_accuracy"], 0.75)
+ self.assertLess(result["train_loss"], 0.5)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "ner_no_trainer")))
+ @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
def test_run_squad_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_qa_no_trainer.py
+ {self.examples_dir}/pytorch/question-answering/run_qa_no_trainer.py
--model_name_or_path bert-base-uncased
--version_2_with_negative
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
@@ -213,19 +195,19 @@ def test_run_squad_no_trainer(self):
--with_tracking
""".split()
- with patch.object(sys, "argv", testargs):
- run_squad_no_trainer.main()
- result = get_results(tmp_dir)
- # Because we use --version_2_with_negative the testing script uses SQuAD v2 metrics.
- self.assertGreaterEqual(result["eval_f1"], 30)
- self.assertGreaterEqual(result["eval_exact"], 30)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "qa_no_trainer")))
+ run_command(self._launch_args + testargs)
+ result = get_results(tmp_dir)
+ # Because we use --version_2_with_negative the testing script uses SQuAD v2 metrics.
+ self.assertGreaterEqual(result["eval_f1"], 28)
+ self.assertGreaterEqual(result["eval_exact"], 28)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "qa_no_trainer")))
+ @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
def test_run_swag_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_swag_no_trainer.py
+ {self.examples_dir}/pytorch/multiple-choice/run_swag_no_trainer.py
--model_name_or_path bert-base-uncased
--train_file tests/fixtures/tests_samples/swag/sample.json
--validation_file tests/fixtures/tests_samples/swag/sample.json
@@ -238,17 +220,17 @@ def test_run_swag_no_trainer(self):
--with_tracking
""".split()
- with patch.object(sys, "argv", testargs):
- run_swag_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertGreaterEqual(result["eval_accuracy"], 0.8)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "swag_no_trainer")))
+ run_command(self._launch_args + testargs)
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["eval_accuracy"], 0.8)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "swag_no_trainer")))
@slow
+ @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
def test_run_summarization_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_summarization_no_trainer.py
+ {self.examples_dir}/pytorch/summarization/run_summarization_no_trainer.py
--model_name_or_path t5-small
--train_file tests/fixtures/tests_samples/xsum/sample.json
--validation_file tests/fixtures/tests_samples/xsum/sample.json
@@ -262,21 +244,21 @@ def test_run_summarization_no_trainer(self):
--with_tracking
""".split()
- with patch.object(sys, "argv", testargs):
- run_summarization_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertGreaterEqual(result["eval_rouge1"], 10)
- self.assertGreaterEqual(result["eval_rouge2"], 2)
- self.assertGreaterEqual(result["eval_rougeL"], 7)
- self.assertGreaterEqual(result["eval_rougeLsum"], 7)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "summarization_no_trainer")))
+ run_command(self._launch_args + testargs)
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["eval_rouge1"], 10)
+ self.assertGreaterEqual(result["eval_rouge2"], 2)
+ self.assertGreaterEqual(result["eval_rougeL"], 7)
+ self.assertGreaterEqual(result["eval_rougeLsum"], 7)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "summarization_no_trainer")))
@slow
+ @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
def test_run_translation_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_translation_no_trainer.py
+ {self.examples_dir}/pytorch/translation/run_translation_no_trainer.py
--model_name_or_path sshleifer/student_marian_en_ro_6_1
--source_lang en
--target_lang ro
@@ -294,12 +276,11 @@ def test_run_translation_no_trainer(self):
--with_tracking
""".split()
- with patch.object(sys, "argv", testargs):
- run_translation_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertGreaterEqual(result["eval_bleu"], 30)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "translation_no_trainer")))
+ run_command(self._launch_args + testargs)
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["eval_bleu"], 30)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "translation_no_trainer")))
@slow
def test_run_semantic_segmentation_no_trainer(self):
@@ -308,7 +289,7 @@ def test_run_semantic_segmentation_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_semantic_segmentation_no_trainer.py
+ {self.examples_dir}/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
--dataset_name huggingface/semantic-segmentation-test-sample
--output_dir {tmp_dir}
--max_train_steps=10
@@ -319,29 +300,34 @@ def test_run_semantic_segmentation_no_trainer(self):
--checkpointing_steps epoch
""".split()
- with patch.object(sys, "argv", testargs):
- run_semantic_segmentation_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)
+ run_command(self._launch_args + testargs)
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)
+ @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
def test_run_image_classification_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_image_classification_no_trainer.py
- --dataset_name huggingface/image-classification-test-sample
+ {self.examples_dir}/pytorch/image-classification/run_image_classification_no_trainer.py
+ --model_name_or_path google/vit-base-patch16-224-in21k
+ --dataset_name hf-internal-testing/cats_vs_dogs_sample
+ --learning_rate 1e-4
+ --per_device_train_batch_size 2
+ --per_device_eval_batch_size 1
+ --max_train_steps 2
+ --train_val_split 0.1
+ --seed 42
--output_dir {tmp_dir}
- --num_warmup_steps=8
- --learning_rate=3e-3
- --per_device_train_batch_size=2
- --per_device_eval_batch_size=1
- --checkpointing_steps epoch
--with_tracking
- --seed 42
+ --checkpointing_steps 1
""".split()
- with patch.object(sys, "argv", testargs):
- run_image_classification_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertGreaterEqual(result["eval_accuracy"], 0.50)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_classification_no_trainer")))
+ if is_cuda_and_apex_available():
+ testargs.append("--fp16")
+
+ run_command(self._launch_args + testargs)
+ result = get_results(tmp_dir)
+ # The base model scores around 25%
+ self.assertGreaterEqual(result["eval_accuracy"], 0.6)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "step_1")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_classification_no_trainer")))
diff --git a/examples/pytorch/text-classification/README.md b/examples/pytorch/text-classification/README.md
index 5f853149e346..391aaf4d3f03 100644
--- a/examples/pytorch/text-classification/README.md
+++ b/examples/pytorch/text-classification/README.md
@@ -22,7 +22,7 @@ Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models)
-and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file
+and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file
(the script might need some tweaks in that case, refer to the comments inside for help).
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
@@ -79,6 +79,8 @@ python run_glue.py \
--output_dir /tmp/imdb/
```
+> If your model's classification head dimensions do not match the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
+
### Mixed precision training
@@ -115,7 +117,7 @@ the mean of the [🤗 `Accelerate`](https://github.com/huggingface/accelerate) l
after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
then
diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py
index b15a0378ca7d..49af0c85568c 100755
--- a/examples/pytorch/text-classification/run_glue.py
+++ b/examples/pytorch/text-classification/run_glue.py
@@ -25,8 +25,9 @@
import datasets
import numpy as np
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
+import evaluate
import transformers
from transformers import (
AutoConfig,
@@ -42,12 +43,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
@@ -89,8 +90,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -99,29 +102,37 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=True,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
train_file: Optional[str] = field(
@@ -180,10 +191,16 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
+ ignore_mismatched_sizes: bool = field(
+ default=False,
+ metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+ )
def main():
@@ -199,6 +216,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_glue", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -352,6 +373,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
+ ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
# Preprocessing the raw_datasets
@@ -459,9 +481,9 @@ def preprocess_function(examples):
# Get the metric function
if data_args.task_name is not None:
- metric = load_metric("glue", data_args.task_name)
+ metric = evaluate.load("glue", data_args.task_name)
else:
- metric = load_metric("accuracy")
+ metric = evaluate.load("accuracy")
# You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
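
The new `--ignore_mismatched_sizes` flag added above maps onto the `ignore_mismatched_sizes` argument of `from_pretrained`. A minimal sketch, where the checkpoint name and label count are arbitrary examples:

```python
# Sketch of what --ignore_mismatched_sizes does under the hood.
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english",  # checkpoint with a 2-label head
    num_labels=5,                                        # the new task has 5 labels
    ignore_mismatched_sizes=True,                        # reinitialize the mismatched head
)
print(model.config.num_labels)  # 5
```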
diff --git a/examples/pytorch/text-classification/run_glue_no_trainer.py b/examples/pytorch/text-classification/run_glue_no_trainer.py
index 38017e77db13..f74e5520699b 100644
--- a/examples/pytorch/text-classification/run_glue_no_trainer.py
+++ b/examples/pytorch/text-classification/run_glue_no_trainer.py
@@ -23,17 +23,17 @@
import datasets
import torch
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
+import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository
from transformers import (
- AdamW,
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
@@ -43,10 +43,13 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.22.0.dev0")
+
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
@@ -168,7 +171,22 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
+ )
+ parser.add_argument(
+ "--ignore_mismatched_sizes",
+ action="store_true",
+ help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
)
args = parser.parse_args()
@@ -191,10 +209,16 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_glue_no_trainer", args)
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator = (
+ Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ )
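
With the new `--report_to` flag, the tracker selection is forwarded to Accelerate instead of the previous hard-coded `"all"`; a minimal sketch using the same keyword arguments as the call above (the integration name is illustrative only):

```python
from accelerate import Accelerator

# Equivalent to running the script with `--with_tracking --report_to wandb`.
accelerator = Accelerator(log_with="wandb", logging_dir="./output")
```
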
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -253,7 +277,7 @@ def main():
data_files["train"] = args.train_file
if args.validation_file is not None:
data_files["validation"] = args.validation_file
- extension = (args.train_file if args.train_file is not None else args.valid_file).split(".")[-1]
+ extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
raw_datasets = load_dataset(extension, data_files=data_files)
# See more about loading any type of standard or custom dataset at
# https://huggingface.co/docs/datasets/loading_datasets.html.
@@ -288,6 +312,7 @@ def main():
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
+ ignore_mismatched_sizes=args.ignore_mismatched_sizes,
)
# Preprocessing the datasets
@@ -325,7 +350,7 @@ def main():
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
"\nIgnoring the model labels as a result.",
)
- elif args.task_name is None:
+ elif args.task_name is None and not is_regression:
label_to_id = {v: i for i, v in enumerate(label_list)}
if label_to_id is not None:
@@ -397,14 +422,14 @@ def preprocess_function(examples):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
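
`transformers.AdamW` is deprecated in favor of the PyTorch implementation; a minimal sketch of the drop-in replacement (the tiny model is a stand-in):

```python
import torch
from torch import nn

tiny_model = nn.Linear(8, 2)  # stand-in for the fine-tuned model
optimizer = torch.optim.AdamW(tiny_model.parameters(), lr=5e-5, weight_decay=0.01)
```
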
# Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- else:
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
@@ -418,9 +443,12 @@ def preprocess_function(examples):
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
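
A worked example of the step/epoch bookkeeping above, with made-up sizes; the point is that a user-supplied `--max_train_steps` is no longer overwritten after `accelerator.prepare()` shrinks the dataloader:

```python
import math

grad_accum, requested_epochs = 4, 3
updates_per_epoch = math.ceil(1000 / grad_accum)            # 250 before prepare()

max_train_steps = None                                      # not set on the CLI
overrode_max_train_steps = max_train_steps is None
if overrode_max_train_steps:
    max_train_steps = requested_epochs * updates_per_epoch  # 750

# After prepare() the dataloader is sharded across processes, e.g. 2 GPUs.
updates_per_epoch = math.ceil(500 / grad_accum)             # 125
if overrode_max_train_steps:
    max_train_steps = requested_epochs * updates_per_epoch  # recomputed: 375
num_train_epochs = math.ceil(max_train_steps / updates_per_epoch)  # always recomputed: 3
```
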
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
@@ -430,18 +458,21 @@ def preprocess_function(examples):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("glue_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("glue_no_trainer", experiment_config)
# Get the metric function
if args.task_name is not None:
- metric = load_metric("glue", args.task_name)
+ metric = evaluate.load("glue", args.task_name)
else:
- metric = load_metric("accuracy")
+ metric = evaluate.load("accuracy")
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -522,7 +553,7 @@ def preprocess_function(examples):
predictions, references = accelerator.gather((predictions, batch["labels"]))
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
- if step == len(eval_dataloader):
+ if step == len(eval_dataloader) - 1:
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
@@ -539,10 +570,11 @@ def preprocess_function(examples):
accelerator.log(
{
"accuracy" if args.task_name is not None else "glue": eval_metric,
- "train_loss": total_loss,
+ "train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
+ step=completed_steps,
)
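
`total_loss` is accumulated as a 0-d tensor over the epoch, so it is now converted to a float and averaged over the number of batches before logging; a minimal sketch with stand-in values:

```python
import torch

total_loss = torch.tensor(123.4)        # accumulated epoch loss (stand-in)
num_batches = 250                       # len(train_dataloader) (stand-in)
print(total_loss.item() / num_batches)  # ~0.4936, a plain Python float
```
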
if args.push_to_hub and epoch < args.num_train_epochs - 1:
diff --git a/examples/pytorch/text-classification/run_xnli.py b/examples/pytorch/text-classification/run_xnli.py
index cd4d44b6a61e..d4cfc3a77d0b 100755
--- a/examples/pytorch/text-classification/run_xnli.py
+++ b/examples/pytorch/text-classification/run_xnli.py
@@ -26,8 +26,9 @@
import datasets
import numpy as np
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
+import evaluate
import transformers
from transformers import (
AutoConfig,
@@ -42,12 +43,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
@@ -67,8 +68,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -77,33 +80,39 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=True,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
- server_ip: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})
- server_port: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})
@dataclass
@@ -146,10 +155,16 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
+ ignore_mismatched_sizes: bool = field(
+ default=False,
+ metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+ )
def main():
@@ -160,14 +175,9 @@ def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
- # Setup distant debugging if needed
- if data_args.server_ip and data_args.server_port:
- # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
- import ptvsd
-
- print("Waiting for debugger attach")
- ptvsd.enable_attach(address=(data_args.server_ip, data_args.server_port), redirect_output=True)
- ptvsd.wait_for_attach()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_xnli", model_args)
# Setup logging
logging.basicConfig(
@@ -279,6 +289,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
+ ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
# Preprocessing the datasets
@@ -339,7 +350,7 @@ def preprocess_function(examples):
)
# Get the metric function
- metric = load_metric("xnli")
+ metric = evaluate.load("xnli")
# You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
diff --git a/examples/pytorch/token-classification/README.md b/examples/pytorch/token-classification/README.md
index 01f586dff2fe..496722cf6b9a 100644
--- a/examples/pytorch/token-classification/README.md
+++ b/examples/pytorch/token-classification/README.md
@@ -55,6 +55,8 @@ uses special features of those tokenizers. You can check if your favorite model
[this table](https://huggingface.co/transformers/index.html#supported-frameworks), if it doesn't you can still use the old version
of the script.
+> If your model's classification head dimensions do not match the number of labels in the dataset, you can pass `--ignore_mismatched_sizes` to reinitialize the head with the right shape.
+
## Old version of the script
You can find the old version of the PyTorch script [here](https://github.com/huggingface/transformers/blob/main/examples/legacy/token-classification/run_ner.py).
@@ -73,7 +75,7 @@ the mean of the [š¤ `Accelerate`](https://github.com/huggingface/accelerate) l
after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
then
diff --git a/examples/pytorch/token-classification/run_ner.py b/examples/pytorch/token-classification/run_ner.py
index fc54c77d6204..9000b5006e03 100755
--- a/examples/pytorch/token-classification/run_ner.py
+++ b/examples/pytorch/token-classification/run_ner.py
@@ -27,8 +27,9 @@
import datasets
import numpy as np
-from datasets import ClassLabel, load_dataset, load_metric
+from datasets import ClassLabel, load_dataset
+import evaluate
import transformers
from transformers import (
AutoConfig,
@@ -43,12 +44,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
@@ -81,10 +82,16 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
+ ignore_mismatched_sizes: bool = field(
+ default=False,
+ metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+ )
@dataclass
@@ -127,44 +134,56 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. If set, sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. If set, sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
label_all_tokens: bool = field(
default=False,
metadata={
- "help": "Whether to put the label for one word on all tokens of generated by that word or just on the "
- "one (in which case the other tokens will have a padding index)."
+ "help": (
+ "Whether to put the label for one word on all tokens of generated by that word or just on the "
+ "one (in which case the other tokens will have a padding index)."
+ )
},
)
return_entity_level_metrics: bool = field(
@@ -198,6 +217,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_ner", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -350,14 +373,15 @@ def get_label_list(labels):
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
+ ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
- "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
- "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
- "requirement"
+ "This example script only works for models that have a fast tokenizer. Checkout the big table of models at"
+ " https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet"
+ " this requirement"
)
# Model has labels -> use them.
@@ -373,8 +397,8 @@ def get_label_list(labels):
else:
logger.warning(
"Your model seems to have been trained with labels, but they don't match the dataset: ",
- f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
- "\nIgnoring the model labels as a result.",
+ f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels:"
+ f" {list(sorted(label_list))}.\nIgnoring the model labels as a result.",
)
# Set the correspondences label/ID inside the model config
@@ -481,7 +505,7 @@ def tokenize_and_align_labels(examples):
data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
# Metrics
- metric = load_metric("seqeval")
+ metric = evaluate.load("seqeval")
def compute_metrics(p):
predictions, labels = p
diff --git a/examples/pytorch/token-classification/run_ner_no_trainer.py b/examples/pytorch/token-classification/run_ner_no_trainer.py
index 234109b5d966..f5736f35c791 100755
--- a/examples/pytorch/token-classification/run_ner_no_trainer.py
+++ b/examples/pytorch/token-classification/run_ner_no_trainer.py
@@ -28,10 +28,11 @@
import datasets
import torch
-from datasets import ClassLabel, load_dataset, load_metric
+from datasets import ClassLabel, load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
+import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -40,7 +41,6 @@
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
- AdamW,
AutoConfig,
AutoModelForTokenClassification,
AutoTokenizer,
@@ -50,10 +50,13 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.22.0.dev0")
+
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
@@ -114,7 +117,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
+ required=False,
)
parser.add_argument(
"--config_name",
@@ -221,7 +224,22 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
+ )
+ parser.add_argument(
+ "--ignore_mismatched_sizes",
+ action="store_true",
+ help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
)
args = parser.parse_args()
@@ -245,9 +263,16 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_ner_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator = (
+ Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ )
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -383,6 +408,7 @@ def get_label_list(labels):
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
+ ignore_mismatched_sizes=args.ignore_mismatched_sizes,
)
else:
logger.info("Training new model from scratch")
@@ -403,8 +429,8 @@ def get_label_list(labels):
else:
logger.warning(
"Your model seems to have been trained with labels, but they don't match the dataset: ",
- f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
- "\nIgnoring the model labels as a result.",
+ f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels:"
+ f" {list(sorted(label_list))}.\nIgnoring the model labels as a result.",
)
# Set the correspondences label/ID inside the model config
@@ -507,18 +533,18 @@ def tokenize_and_align_labels(examples):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Use the device given by the `accelerator` object.
device = accelerator.device
model.to(device)
# Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- else:
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
@@ -534,7 +560,10 @@ def tokenize_and_align_labels(examples):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
@@ -544,15 +573,18 @@ def tokenize_and_align_labels(examples):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("ner_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("ner_no_trainer", experiment_config)
# Metrics
- metric = load_metric("seqeval")
+ metric = evaluate.load("seqeval")
def get_labels(predictions, references):
# Transform predictions and references tensors to numpy arrays
@@ -677,7 +709,7 @@ def compute_metrics():
predictions_gathered, labels_gathered = accelerator.gather((predictions, labels))
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
- if step == len(eval_dataloader):
+ if step == len(eval_dataloader) - 1:
predictions_gathered = predictions_gathered[: len(eval_dataloader.dataset) - samples_seen]
labels_gathered = labels_gathered[: len(eval_dataloader.dataset) - samples_seen]
else:
@@ -692,7 +724,13 @@ def compute_metrics():
accelerator.print(f"epoch {epoch}:", eval_metric)
if args.with_tracking:
accelerator.log(
- {"seqeval": eval_metric, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
+ {
+ "seqeval": eval_metric,
+ "train_loss": total_loss.item() / len(train_dataloader),
+ "epoch": epoch,
+ "step": completed_steps,
+ },
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
@@ -725,7 +763,9 @@ def compute_metrics():
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
- json.dump({"eval_accuracy": eval_metric["accuracy"], "train_loss": float(loss.cpu().detach().numpy())}, f)
+ json.dump(
+ {"eval_accuracy": eval_metric["accuracy"], "train_loss": total_loss.item() / len(train_dataloader)}, f
+ )
if __name__ == "__main__":
diff --git a/examples/pytorch/translation/README.md b/examples/pytorch/translation/README.md
index 00c03a9be139..4bd66ea0acd1 100644
--- a/examples/pytorch/translation/README.md
+++ b/examples/pytorch/translation/README.md
@@ -162,7 +162,7 @@ the mean of the [š¤ `Accelerate`](https://github.com/huggingface/accelerate) l
after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
then
diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py
index 6f2630104f7e..af1868b25aad 100755
--- a/examples/pytorch/translation/run_translation.py
+++ b/examples/pytorch/translation/run_translation.py
@@ -26,8 +26,9 @@
import datasets
import numpy as np
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
+import evaluate
import transformers
from transformers import (
AutoConfig,
@@ -46,12 +47,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
@@ -91,8 +92,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -116,15 +119,12 @@ class DataTrainingArguments:
validation_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
- "a jsonlines file."
+ "help": "An optional input evaluation data file to evaluate the metrics (sacrebleu) on a jsonlines file."
},
)
test_file: Optional[str] = field(
default=None,
- metadata={
- "help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
- },
+ metadata={"help": "An optional input test data file to evaluate the metrics (sacrebleu) on a jsonlines file."},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
@@ -136,60 +136,76 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
- "which is used during ``evaluate`` and ``predict``."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``."
+ )
},
)
ignore_pad_token_for_loss: bool = field(
@@ -204,9 +220,11 @@ class DataTrainingArguments:
forced_bos_token: Optional[str] = field(
default=None,
metadata={
- "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
- "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
- "needs to be the target language token.(Usually it is the target language token)"
+ "help": (
+ "The token to force as the first generated token after the :obj:`decoder_start_token_id`.Useful for"
+ " multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to"
+ " be the target language token.(Usually it is the target language token)"
+ )
},
)
@@ -243,6 +261,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_translation", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -422,9 +444,8 @@ def preprocess_function(examples):
inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
- # Setup the tokenizer for targets
- with tokenizer.as_target_tokenizer():
- labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
+ # Tokenize targets with the `text_target` keyword argument
+ labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
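
Seq2seq target tokenization now uses the `text_target` keyword instead of the `as_target_tokenizer` context manager; a minimal sketch (the checkpoint and sentences are chosen only for illustration):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
batch = tokenizer(
    ["translate English to German: How are you?"],  # source side -> input_ids
    text_target=["Wie geht es dir?"],               # target side -> labels
    max_length=128,
    truncation=True,
)
print(sorted(batch.keys()))  # ['attention_mask', 'input_ids', 'labels']
```
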
@@ -502,7 +523,7 @@ def preprocess_function(examples):
)
# Metric
- metric = load_metric("sacrebleu")
+ metric = evaluate.load("sacrebleu")
def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
diff --git a/examples/pytorch/translation/run_translation_no_trainer.py b/examples/pytorch/translation/run_translation_no_trainer.py
index 21eadf6aaee7..a6b0988f63d0 100644
--- a/examples/pytorch/translation/run_translation_no_trainer.py
+++ b/examples/pytorch/translation/run_translation_no_trainer.py
@@ -29,10 +29,11 @@
import datasets
import numpy as np
import torch
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
+import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -41,7 +42,6 @@
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
- AdamW,
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
@@ -52,10 +52,13 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.22.0.dev0")
+
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
@@ -95,41 +98,51 @@ def parse_args():
"--num_beams",
type=int,
default=None,
- help="Number of beams to use for evaluation. This argument will be "
- "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+ help=(
+ "Number of beams to use for evaluation. This argument will be "
+ "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``."
+ ),
)
parser.add_argument(
"--max_source_length",
type=int,
default=1024,
- help="The maximum total input sequence length after "
- "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after "
+ "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--max_target_length",
type=int,
default=128,
- help="The maximum total sequence length for target text after "
- "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
- "during ``evaluate`` and ``predict``.",
+ help=(
+ "The maximum total sequence length for target text after "
+ "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
+ "during ``evaluate`` and ``predict``."
+ ),
)
parser.add_argument(
"--val_max_target_length",
type=int,
default=None,
- help="The maximum total sequence length for validation "
- "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
- "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
- "param of ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+ help=(
+ "The maximum total sequence length for validation "
+ "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
+ "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
+ "param of ``model.generate``, which is used during ``evaluate`` and ``predict``."
+ ),
)
parser.add_argument(
"--pad_to_max_length",
type=bool,
default=False,
- help="Whether to pad all samples to model maximum sentence "
- "length. If False, will pad the samples dynamically when batching to the maximum length in the batch. More"
- "efficient on GPU but very bad for TPU.",
+ help=(
+ "Whether to pad all samples to model maximum sentence "
+ "length. If False, will pad the samples dynamically when batching to the maximum length in the batch. More"
+ "efficient on GPU but very bad for TPU."
+ ),
)
parser.add_argument(
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
@@ -138,7 +151,7 @@ def parse_args():
"--ignore_pad_token_for_loss",
type=bool,
default=True,
- help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.",
+ help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.",
)
parser.add_argument("--source_lang", type=str, default=None, help="Source language id for translation.")
parser.add_argument("--target_lang", type=str, default=None, help="Target language id for translation.")
@@ -146,7 +159,7 @@ def parse_args():
"--source_prefix",
type=str,
default=None,
- help="A prefix to add before every source text " "(useful for T5 models).",
+ help="A prefix to add before every source text (useful for T5 models).",
)
parser.add_argument(
"--preprocessing_num_workers",
@@ -170,7 +183,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
+ required=False,
)
parser.add_argument(
"--config_name",
@@ -260,7 +273,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
)
args = parser.parse_args()
@@ -286,9 +309,16 @@ def main():
# Parse the arguments
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_translation_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator = (
+ Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ )
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
@@ -426,9 +456,8 @@ def preprocess_function(examples):
inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
- # Setup the tokenizer for targets
- with tokenizer.as_target_tokenizer():
- labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
+ # Tokenize targets with the `text_target` keyword argument
+ labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
@@ -492,14 +521,14 @@ def preprocess_function(examples):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- else:
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
@@ -515,8 +544,10 @@ def preprocess_function(examples):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
@@ -525,14 +556,17 @@ def preprocess_function(examples):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("translation_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("translation_no_trainer", experiment_config)
- metric = load_metric("sacrebleu")
+ metric = evaluate.load("sacrebleu")
def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
@@ -651,11 +685,11 @@ def postprocess_text(preds, labels):
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
- if step == len(eval_dataloader):
+ if step == len(eval_dataloader) - 1:
decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
else:
- samples_seen += decoded_labels.shape[0]
+ samples_seen += len(decoded_labels)
metric.add_batch(predictions=decoded_preds, references=decoded_labels)
eval_metric = metric.compute()
@@ -663,7 +697,13 @@ def postprocess_text(preds, labels):
if args.with_tracking:
accelerator.log(
- {"blue": eval_metric["score"], "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
+ {
+ "bleu": eval_metric["score"],
+ "train_loss": total_loss.item() / len(train_dataloader),
+ "epoch": epoch,
+ "step": completed_steps,
+ },
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
diff --git a/examples/pytorch/xla_spawn.py b/examples/pytorch/xla_spawn.py
index d84b41994564..5df6bfa2d5dc 100644
--- a/examples/pytorch/xla_spawn.py
+++ b/examples/pytorch/xla_spawn.py
@@ -39,9 +39,7 @@ def parse_args():
"""
parser = ArgumentParser(
description=(
- "PyTorch TPU distributed training launch "
- "helper utility that will spawn up "
- "multiple distributed processes"
+ "PyTorch TPU distributed training launch helper utility that will spawn up multiple distributed processes"
)
)
diff --git a/examples/research_projects/adversarial/run_hans.py b/examples/research_projects/adversarial/run_hans.py
index 31acbd3a8a6f..0576471fbc50 100644
--- a/examples/research_projects/adversarial/run_hans.py
+++ b/examples/research_projects/adversarial/run_hans.py
@@ -77,8 +77,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -110,7 +112,8 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
+ " --overwrite_output_dir to overcome."
)
# Setup logging
diff --git a/examples/research_projects/adversarial/utils_hans.py b/examples/research_projects/adversarial/utils_hans.py
index b02bf8135277..e54792ad2f82 100644
--- a/examples/research_projects/adversarial/utils_hans.py
+++ b/examples/research_projects/adversarial/utils_hans.py
@@ -197,7 +197,7 @@ def __init__(
self.features = hans_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer)
def gen():
- for (ex_index, ex) in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
+ for ex_index, ex in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
if ex_index % 10000 == 0:
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
@@ -268,7 +268,7 @@ def get_labels(self):
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
- for (i, line) in enumerate(lines):
+ for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
@@ -303,7 +303,7 @@ def hans_convert_examples_to_features(
label_map = {label: i for i, label in enumerate(label_list)}
features = []
- for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
+ for ex_index, example in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
if ex_index % 10000 == 0:
logger.info("Writing example %d" % (ex_index))
diff --git a/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_albert.py b/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_albert.py
index 006ff98c950f..5e17352dc19b 100644
--- a/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_albert.py
+++ b/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_albert.py
@@ -84,7 +84,10 @@ def reset_stats(self):
def log_stats(self):
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
- message = f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up = {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
+ message = (
+ f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up ="
+ f" {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
+ )
print(message)
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING)
diff --git a/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_bert.py b/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_bert.py
index ff5c2b51e8b3..b32f47d0c300 100644
--- a/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_bert.py
+++ b/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_bert.py
@@ -89,7 +89,10 @@ def reset_stats(self):
def log_stats(self):
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
- message = f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up = {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
+ message = (
+ f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up ="
+ f" {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
+ )
print(message)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
diff --git a/examples/research_projects/bert-loses-patience/run_glue_with_pabee.py b/examples/research_projects/bert-loses-patience/run_glue_with_pabee.py
index def4dff77664..d4121655e823 100755
--- a/examples/research_projects/bert-loses-patience/run_glue_with_pabee.py
+++ b/examples/research_projects/bert-loses-patience/run_glue_with_pabee.py
@@ -483,8 +483,10 @@ def main():
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -574,8 +576,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument(
"--local_rank",
diff --git a/examples/research_projects/bertabs/run_summarization.py b/examples/research_projects/bertabs/run_summarization.py
index 33be67233ff6..fcfae6b8c6c7 100644
--- a/examples/research_projects/bertabs/run_summarization.py
+++ b/examples/research_projects/bertabs/run_summarization.py
@@ -325,7 +325,8 @@ def main():
if not documents_dir_is_valid(args.documents_dir):
raise FileNotFoundError(
- "We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
+ "We could not find the directory you specified for the documents to summarize, or it was empty. Please"
+ " specify a valid path."
)
os.makedirs(args.summaries_output_dir, exist_ok=True)
diff --git a/examples/research_projects/bertology/run_bertology.py b/examples/research_projects/bertology/run_bertology.py
index 1018359dc62e..030573d87f35 100644
--- a/examples/research_projects/bertology/run_bertology.py
+++ b/examples/research_projects/bertology/run_bertology.py
@@ -338,8 +338,10 @@ def main():
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after WordPiece tokenization. \n"
- "Sequences longer than this will be truncated, sequences shorter padded.",
+ help=(
+ "The maximum total input sequence length after WordPiece tokenization. \n"
+ "Sequences longer than this will be truncated, sequences shorter padded."
+ ),
)
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
diff --git a/examples/research_projects/bertology/run_prune_gpt.py b/examples/research_projects/bertology/run_prune_gpt.py
index 49a867b96dd4..68cece6e997a 100644
--- a/examples/research_projects/bertology/run_prune_gpt.py
+++ b/examples/research_projects/bertology/run_prune_gpt.py
@@ -314,8 +314,10 @@ def main():
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after WordPiece tokenization. \n"
- "Sequences longer than this will be truncated, sequences shorter padded.",
+ help=(
+ "The maximum total input sequence length after WordPiece tokenization. \n"
+ "Sequences longer than this will be truncated, sequences shorter padded."
+ ),
)
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
diff --git a/examples/research_projects/codeparrot/README.md b/examples/research_projects/codeparrot/README.md
index 2b51b3ba4b57..ef92606c545a 100644
--- a/examples/research_projects/codeparrot/README.md
+++ b/examples/research_projects/codeparrot/README.md
@@ -37,30 +37,46 @@ Additionally, sure you have git-lfs installed. You can find instructions for how
The source of the dataset is the GitHub dump available on Google's [BigQuery](https://cloud.google.com/blog/topics/public-datasets/github-on-bigquery-analyze-all-the-open-source-code). The database was queried for all Python files with less than 1MB in size resulting in a 180GB dataset with over 20M files. The dataset is available on the Hugging Face Hub [here](https://huggingface.co/datasets/transformersbook/codeparrot).
### Preprocessing
-The raw dataset contains many duplicates. We deduplicated and filtered the dataset using the heuristics proposed in OpenAI's Codex [paper](https://arxiv.org/abs/2107.03374):
+The raw dataset contains many duplicates. We deduplicated and filtered the dataset using the heuristics proposed in OpenAI's Codex [paper](https://arxiv.org/abs/2107.03374) and some new ones:
-- exact deduplication using each file's hash
+- exact deduplication using each file's hash after removing whitespace.
+- near deduplication using MinHash and Jaccard similarity. MinHash with a Jaccard threshold (default=0.85) is first used to create duplicate clusters. These clusters are then reduced to unique files based on the exact Jaccard similarity (see the sketch after this list). See `deduplicate_dataset` in `minhash_deduplication.py` for a detailed description.
- filtering files with max line length > 1000
- filtering files with mean line length > 100
- fraction of alphanumeric characters < 0.25
- containing the word "auto-generated" or similar in the first 5 lines
+- filtering, with a probability of 0.7, files that mention "test file" or "configuration file" or similar in the first 5 lines
+- filtering, with a probability of 0.7, files with a high occurrence of the keywords "test " or "config"
+- filtering, with a probability of 0.7, files without a mention of the keywords `def`, `for`, `while` and `class`
+- filtering files that use the assignment operator `=` fewer than 5 times
+- filtering files whose character-to-token ratio after tokenization is below 1.5 (the average ratio is 3.6)
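+
+For illustration, here is a minimal sketch of the near-deduplication idea: MinHash signatures are bucketed with locality-sensitive hashing, and candidates are then confirmed with the exact Jaccard similarity of their token sets. The file names and snippets below are made up; the actual logic (including the `MIN_NUM_TOKENS` filter and the cluster reduction) lives in `minhash_deduplication.py`.
+```python
+import re
+
+from datasketch import MinHash, MinHashLSH
+
+NUM_PERM = 256
+JACCARD_THRESHOLD = 0.85
+
+
+def tokens(code):
+    # split on non-alphanumeric characters, as in minhash_deduplication.py
+    return {t for t in re.split(r"[^A-Za-z_0-9]", code) if t.strip()}
+
+
+def min_hash(code):
+    m = MinHash(num_perm=NUM_PERM)
+    for token in tokens(code):
+        m.update(token.encode())
+    return m
+
+
+snippet = "def add(a, b):\n    total = a + b\n    return total"
+files = {"original.py": snippet, "copy.py": snippet, "other.py": "print('hello world')"}
+
+# 1) bucket candidate duplicates with LSH on the MinHash signatures
+lsh = MinHashLSH(threshold=JACCARD_THRESHOLD, num_perm=NUM_PERM)
+for name, code in files.items():
+    lsh.insert(name, min_hash(code))
+candidates = lsh.query(min_hash(files["original.py"]))  # contains "original.py" and "copy.py"
+
+# 2) confirm candidates with the exact Jaccard similarity of the token sets
+a, b = tokens(files["original.py"]), tokens(files["copy.py"])
+exact_jaccard = len(a & b) / len(a | b)
+```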
-The script to process the full dataset can be found in `scripts/preprocessing.py`. Executing the script on 16 vCPUs takes roughly 3h and removes 70% of the original dataset. The cleaned [train](https://huggingface.co/datasets/lvwerra/codeparrot-clean-train) and [validation](https://huggingface.co/datasets/lvwerra/codeparrot-clean-valid) splits are also available on the Hub if you want to skip this step or use the data for another project.
+The script to process the full dataset can be found in `scripts/preprocessing.py`. Executing the script on 16 vCPUs takes roughly 3h and removes 70% of the original dataset. The cleaned [train](https://huggingface.co/datasets/codeparrot/codeparrot-clean-train-v2) and [validation](https://huggingface.co/datasets/codeparrot/codeparrot-clean-valid-v2) splits are also available on the Hub if you want to skip this step or use the data for another project.
To execute the preprocessing run the following command:
```bash
python scripts/preprocessing.py \
---dataset_name lvwerra/codeparrot \
+--dataset_name transformersbook/codeparrot \
--output_dir codeparrot-clean
```
During preprocessing the dataset is downloaded and stored locally as well as caches of the computations. Make sure you have more than 500GB free disk space to execute it.
+### Pretokenization
+Tokenizing the data during training can be slow, especially for small models. We provide code to pretokenize the data beforehand in `scripts/pretokenizing.py`, but this step is optional. The dataset is downloaded and stored locally, and the tokenized data is pushed to the Hub. The tokenized clean [train](https://huggingface.co/datasets/codeparrot/tokenized-codeparrot-train) and [validation](https://huggingface.co/datasets/codeparrot/tokenized-codeparrot-valid) datasets are available if you want to use them directly.
+
+To execute the pretokenization, for the clean train data for instance, run the following command:
+```bash
+python scripts/pretokenizing.py \
+--dataset_name codeparrot/codeparrot-clean-train \
+--tokenized_data_repo tokenized-codeparrot-train
+```
+
## Tokenizer
Before training a new model for code we create a new tokenizer that is efficient at code tokenization. To train the tokenizer you can run the following command:
```bash
python scripts/bpe_training.py \
--base_tokenizer gpt2 \
- --dataset_name lvwerra/codeparrot-clean-train
+ --dataset_name codeparrot/codeparrot-clean-train
```
_Note:_ We originally trained the tokenizer on the unprocessed train split of the dataset `transformersbook/codeparrot-train`.
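+
+Under the hood the script trains a new byte-level BPE tokenizer starting from the GPT-2 one. Below is a minimal sketch of the idea, assuming `train_new_from_iterator` with the defaults from `scripts/arguments.py` (200,000 vocabulary size, 32,768 examples); the exact implementation lives in `scripts/bpe_training.py`.
+```python
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+# stream the unprocessed train split and iterate over batches of raw code
+dataset = load_dataset("transformersbook/codeparrot-train", split="train", streaming=True)
+
+
+def batch_iterator(batch_size=1000, n_examples=32_768):
+    batch = []
+    for i, example in enumerate(dataset):
+        if i >= n_examples:
+            break
+        batch.append(example["content"])
+        if len(batch) == batch_size:
+            yield batch
+            batch = []
+    if batch:
+        yield batch
+
+
+base_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+new_tokenizer = base_tokenizer.train_new_from_iterator(batch_iterator(), vocab_size=200_000)
+new_tokenizer.save_pretrained("codeparrot-tokenizer")
+```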
@@ -71,13 +87,14 @@ The models are randomly initialized and trained from scratch. To initialize a ne
```bash
python scripts/initialize_model.py \
--config_name gpt2-large \
---tokenizer_name lvwerra/codeparrot \
+--tokenizer_name codeparrot/codeparrot \
--model_name codeparrot \
--push_to_hub True
```
-This will initialize a new model with the architecture and configuration of `gpt2-large` and use the tokenizer to appropriately size the input embeddings. Finally, the initilaized model is pushed the the hub.
+This will initialize a new model with the architecture and configuration of `gpt2-large` and use the tokenizer to appropriately size the input embeddings. Finally, the initialized model is pushed to the Hub.
-Now that the dataset, tokenizer, and model are ready we can start training the model. The main training script is built with `accelerate` to scale across a wide range of platforms and infrastructure scales. We train two models with [110M](https://huggingface.co/lvwerra/codeparrot-small/) and [1.5B](https://huggingface.co/lvwerra/codeparrot/) parameters for 25-30B tokens on a 16xA100 (40GB) machine which takes 1 day and 1 week, respectively.
+We can pass either the name of a text dataset or a pretokenized dataset, which speeds up training a bit.
+Now that the tokenizer and model are ready we can start training. The main training script is built with `accelerate` to scale across a wide range of platforms and infrastructure scales. We train two models with [110M](https://huggingface.co/codeparrot/codeparrot-small/) and [1.5B](https://huggingface.co/codeparrot/codeparrot/) parameters for 25-30B tokens on a 16xA100 (40GB) machine, which takes 1 day and 1 week, respectively.
First you need to configure `accelerate` and login to Weights & Biases:
@@ -89,14 +106,14 @@ wandb login
Note that during the `accelerate` configuration we enabled FP16. Then to train the large model you can run
```bash
-python scripts/codeparrot_training.py
+accelerate launch scripts/codeparrot_training.py
```
If you want to train the small model you need to make some modifications:
```bash
accelerate launch scripts/codeparrot_training.py \
---model_ckpt lvwerra/codeparrot-small \
+--model_ckpt codeparrot/codeparrot-small \
--train_batch_size 12 \
--valid_batch_size 12 \
--learning_rate 5e-4 \
@@ -118,15 +135,15 @@ Instead of streaming the dataset from the hub you can also stream it from disk.
```bash
git lfs install
mkdir data
-git -C "./data" clone https://huggingface.co/datasets/lvwerra/codeparrot-clean-train
-git -C "./data" clone https://huggingface.co/datasets/lvwerra/codeparrot-clean-valid
+git -C "./data" clone https://huggingface.co/datasets/codeparrot/codeparrot-clean-train
+git -C "./data" clone https://huggingface.co/datasets/codeparrot/codeparrot-clean-valid
```
And then pass the paths to the datasets when we run the training script:
```bash
accelerate launch scripts/codeparrot_training.py \
---model_ckpt lvwerra/codeparrot-small \
+--model_ckpt codeparrot/codeparrot-small \
--dataset_name_train ./data/codeparrot-clean-train \
--dataset_name_valid ./data/codeparrot-clean-valid \
--train_batch_size 12 \
@@ -143,13 +160,13 @@ accelerate launch scripts/codeparrot_training.py \
For evaluating the language modeling loss on the validation set or any other dataset you can use the following command:
```bash
python scripts/validation_loss.py \
---model_ckpt lvwerra/codeparrot \
---dataset_name lvwerra/codeparrot-clean-valid
+--model_ckpt codeparrot/codeparrot \
+--dataset_name codeparrot/codeparrot-clean-valid
```
In addition we evaluate the model on OpenAI's _HumanEval_ benchmark. You can run the evaluation with the following command:
```bash
-python scripts/human_eval.py --model_ckpt lvwerra/codeparrot \
+accelerate launch scripts/human_eval.py --model_ckpt codeparrot/codeparrot \
--do_sample True \
--temperature 0.2 \
--top_p 0.95 \
@@ -162,7 +179,7 @@ The results as well as reference values are shown in the following table:
| Model | pass@1 | pass@10 | pass@100|
|-------|--------|---------|---------|
|CodeParrot š¦ (110M) | 3.80% | 6.57% | 12.78% |
-|CodeParrot š¦ (1.5B) | 3.58% | 8.03% | 14.96% |
+|CodeParrot š¦ (1.5B) | 3.99% | 8.69% | 17.88% |
|||||
|Codex (25M)| 3.21% | 7.1% | 12.89%|
|Codex (85M)| 8.22% | 12.81% | 22.40% |
@@ -177,9 +194,117 @@ The results as well as reference values are shown in the following table:
The numbers were obtained by sampling with `T = [0.2, 0.6, 0.8]` and picking the best value for each metric. Both CodeParrot š¦ models are still underfitted and longer training would likely improve the performance.
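+
+For reference, pass@k is computed with the unbiased estimator from the Codex paper: for each problem we generate n samples, count the c samples that pass the unit tests, and estimate pass@k as 1 - C(n-c, k)/C(n, k). A minimal sketch is shown below (the `code_eval` metric used by `scripts/human_eval.py` implements this estimator):
+```python
+import numpy as np
+
+
+def pass_at_k(n, c, k):
+    """Unbiased pass@k: n = samples generated per task, c = samples that pass the unit tests."""
+    if n - c < k:
+        return 1.0
+    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
+
+
+# e.g. 200 generations for a task, 10 of which pass its tests
+print(pass_at_k(n=200, c=10, k=100))
+```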
## Demo
-Give the model a shot yourself! There are two demos to interact with CodeParrot š¦:
-- [Code generation](https://huggingface.co/spaces/lvwerra/codeparrot-generation)
-- [Code highlighting](https://huggingface.co/spaces/lvwerra/codeparrot-highlighting)
+Give the model a shot yourself! There are three demos to interact with CodeParrot š¦:
+- [Code generation](https://huggingface.co/spaces/codeparrot/codeparrot-generation)
+- [Code highlighting](https://huggingface.co/spaces/codeparrot/codeparrot-highlighting)
+- [Comparison to other code models](https://huggingface.co/spaces/loubnabnl/code-generation-models)
+
+## Training with Megatron
+[Megatron](https://github.com/NVIDIA/Megatron-LM) is a framework developed by NVIDIA for training large transformer models. While the CodeParrot code is easy to follow and modify to your needs, the Megatron framework lets you train models faster. Below we explain how to use it.
+
+### Setup
+You can pull an NVIDIA PyTorch Container that comes with all the required installations from [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch). See the [documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for more details.
+
+With the following Docker command you can run the container (`xx.xx` denotes the container version), and clone the [Megatron repository](https://github.com/NVIDIA/Megatron-LM) into it:
+```bash
+docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:xx.xx-py3
+git clone https://github.com/NVIDIA/Megatron-LM
+```
+
+You also need to copy the vocabulary file and merges table of the tokenizer that you trained on code into the container. You can find these files at [vocab.json](https://huggingface.co/codeparrot/codeparrot/raw/main/vocab.json) and [merges.txt](https://huggingface.co/codeparrot/codeparrot/raw/main/merges.txt).
+```bash
+sudo docker cp vocab.json CONTAINER_ID:/workspace/Megatron-LM
+sudo docker cp merges.txt CONTAINER_ID:/workspace/Megatron-LM
+```
+
+### Data preprocessing
+The training data requires preprocessing. First, you need to convert it into a loose JSON format, with one JSON object containing a text sample per line. In Python this can be done as follows:
+```python
+from datasets import load_dataset
+
+train_data = load_dataset('codeparrot/codeparrot-clean-train', split='train')
+train_data.to_json("codeparrot_data.json", lines=True)
+```
+
+The data is then tokenized, shuffled and processed into a binary format for training using the following command:
+```bash
+pip install nltk
+cd Megatron-LM
+python tools/preprocess_data.py \
+ --input codeparrot_data.json \
+ --output-prefix codeparrot \
+ --vocab vocab.json \
+ --dataset-impl mmap \
+ --tokenizer-type GPT2BPETokenizer \
+ --merge-file merges.txt \
+ --json-keys content \
+ --workers 32 \
+ --chunk-size 25 \
+ --append-eod
+```
+This outputs two files, `codeparrot_content_document.idx` and `codeparrot_content_document.bin`, which are used in the training.
+
+### Training
+You can configure the model architecture and training parameters as shown below, or put them in a bash script that you will run. This command runs the pretraining of the 110M-parameter CodeParrot model on 8 GPUs, with the same settings as before. Note that the data is partitioned by default into a 969:30:1 ratio for training/validation/test sets.
+```bash
+GPUS_PER_NODE=8
+MASTER_ADDR=localhost
+MASTER_PORT=6001
+NNODES=1
+NODE_RANK=0
+WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
+DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
+CHECKPOINT_PATH=/workspace/Megatron-LM/experiments/codeparrot-small
+VOCAB_FILE=vocab.json
+MERGE_FILE=merges.txt
+DATA_PATH=codeparrot_content_document
+GPT_ARGS="--num-layers 12
+--hidden-size 768
+--num-attention-heads 12
+--seq-length 1024
+--max-position-embeddings 1024
+--micro-batch-size 12
+--global-batch-size 192
+--lr 0.0005
+--train-iters 150000
+--lr-decay-iters 150000
+--lr-decay-style cosine
+--lr-warmup-iters 2000
+--weight-decay .1
+--adam-beta2 .999
+--fp16
+--log-interval 10
+--save-interval 2000
+--eval-interval 200
+--eval-iters 10
+"
+TENSORBOARD_ARGS="--tensorboard-dir experiments/tensorboard"
+python3 -m torch.distributed.launch $DISTRIBUTED_ARGS \
+ pretrain_gpt.py \
+ --tensor-model-parallel-size 1 \
+ --pipeline-model-parallel-size 1 \
+ $GPT_ARGS \
+ --vocab-file $VOCAB_FILE \
+ --merge-file $MERGE_FILE \
+ --save $CHECKPOINT_PATH \
+ --load $CHECKPOINT_PATH \
+ --data-path $DATA_PATH \
+ $TENSORBOARD_ARGS
+```
+The training takes almost 12 hours in this setting.
+
+### Convert model to `transformers`
+After training we want to use the model in `transformers`, e.g. to evaluate it on HumanEval. You can convert it to `transformers` by following [this](https://huggingface.co/nvidia/megatron-gpt2-345m) tutorial. For instance, after training has finished you can copy the weights of the last iteration (150k) and convert the `model_optim_rng.pt` file to a `pytorch_model.bin` file that is supported by `transformers`.
+
+```bash
+mkdir -p nvidia/megatron-codeparrot-small
+sudo docker cp CONTAINER_ID:/workspace/Megatron-LM/experiments/codeparrot-small/iter_0150000/mp_rank_00/model_optim_rng.pt nvidia/megatron-codeparrot-small
+git clone https://github.com/huggingface/transformers.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+export PYTHONPATH=Megatron-LM
+python transformers/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py nvidia/megatron-codeparrot-small/model_optim_rng.pt
+```
+Be careful: if you plan to load the tokenizer from the conversion output folder, you will need to replace the generated vocabulary file and merges table with the original ones after the conversion.
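+
+For example, once the conversion has produced a `config.json` and `pytorch_model.bin` in the local folder created above, you can load the model with `transformers` and pair it with the original CodeParrot tokenizer instead of the generated vocabulary files (a sketch, with the paths used in the commands above):
+```python
+from transformers import AutoTokenizer, GPT2LMHeadModel
+
+model = GPT2LMHeadModel.from_pretrained("nvidia/megatron-codeparrot-small")
+# load the original tokenizer rather than the vocab/merges files emitted by the conversion
+tokenizer = AutoTokenizer.from_pretrained("codeparrot/codeparrot")
+
+inputs = tokenizer("def hello_world():", return_tensors="pt")
+outputs = model.generate(**inputs, max_new_tokens=32)
+print(tokenizer.decode(outputs[0]))
+```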
## Further Resources
A detailed description of the project can be found in the chapter "Training Transformers from Scratch" in the upcoming O'Reilly book [Natural Language Processing with Transformers](https://learning.oreilly.com/library/view/natural-language-processing/9781098103231/).
diff --git a/examples/research_projects/codeparrot/requirements.txt b/examples/research_projects/codeparrot/requirements.txt
index a8aadb4ed973..7eff3ac7f135 100644
--- a/examples/research_projects/codeparrot/requirements.txt
+++ b/examples/research_projects/codeparrot/requirements.txt
@@ -1,7 +1,9 @@
-transformers==4.15.0
+transformers==4.19.0
datasets==1.16.0
-accelerate==0.6.2
wandb==0.12.0
tensorboard==2.6.0
-torch==1.9.0
-huggingface-hub==0.1.0
\ No newline at end of file
+torch==1.11.0
+huggingface-hub==0.1.0
+git+https://github.com/huggingface/accelerate.git@3c45b6f760ad8745be9ebc9bbb26f5b04dea4abe
+datasketch==1.5.7
+dpu_utils
\ No newline at end of file
diff --git a/examples/research_projects/codeparrot/scripts/arguments.py b/examples/research_projects/codeparrot/scripts/arguments.py
index a94cda2d2f1b..4def9ac3b854 100644
--- a/examples/research_projects/codeparrot/scripts/arguments.py
+++ b/examples/research_projects/codeparrot/scripts/arguments.py
@@ -9,24 +9,22 @@ class TrainingArguments:
"""
model_ckpt: Optional[str] = field(
- default="lvwerra/codeparrot",
- metadata={"help": "Model name or path of model to be trained."},
+ default="codeparrot/codeparrot", metadata={"help": "Model name or path of model to be trained."}
)
save_dir: Optional[str] = field(
- default="./",
- metadata={"help": "Save dir where model repo is cloned and models updates are saved to."},
+ default="./", metadata={"help": "Save dir where model repo is cloned and models updates are saved to."}
)
dataset_name_train: Optional[str] = field(
- default="lvwerra/codeparrot-clean-train", metadata={"help": "Name or path of training dataset."}
+ default="codeparrot/codeparrot-clean-train", metadata={"help": "Name or path of training dataset."}
)
dataset_name_valid: Optional[str] = field(
- default="lvwerra/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
+ default="codeparrot/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
)
train_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for training."})
valid_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for evaluation."})
weight_decay: Optional[float] = field(default=0.1, metadata={"help": "Value of weight decay."})
shuffle_buffer: Optional[int] = field(
- default=1000, metadata={"help": "Size of buffer used to shuffle streaming dataset."}
+ default=10000, metadata={"help": "Size of buffer used to shuffle streaming dataset."}
)
learning_rate: Optional[float] = field(default=2e-4, metadata={"help": "Learning rate fo training."})
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "Learning rate."})
@@ -39,7 +37,7 @@ class TrainingArguments:
gradient_checkpointing: Optional[bool] = field(
default=True, metadata={"help": "Use gradient checkpointing to reduce memory footprint."}
)
- max_train_steps: Optional[int] = field(default=50_000, metadata={"help": "Maximum number of training steps."})
+ max_train_steps: Optional[int] = field(default=50000, metadata={"help": "Maximum number of training steps."})
max_eval_steps: Optional[int] = field(
default=-1, metadata={"help": "Maximum number of evaluation steps. If -1 the full dataset is evaluated."}
)
@@ -50,9 +48,9 @@ class TrainingArguments:
metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."},
)
resume_from_checkpoint: Optional[str] = field(
- default=None,
- metadata={"help": "States path if the training should continue from a checkpoint folder."},
+ default=None, metadata={"help": "States path if the training should continue from a checkpoint folder."}
)
+ tokenized: Optional[bool] = field(default=False, metadata={"help": "If True the data is pretokenized."})
@dataclass
@@ -62,11 +60,10 @@ class EvaluationArguments:
"""
model_ckpt: Optional[str] = field(
- default="lvwerra/codeparrot",
- metadata={"help": "Model name or path of model to be evaluated."},
+ default="codeparrot/codeparrot", metadata={"help": "Model name or path of model to be evaluated."}
)
dataset_name: Optional[str] = field(
- default="lvwerra/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
+ default="codeparrot/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
)
batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size used for evaluation."})
max_eval_steps: Optional[int] = field(
@@ -83,8 +80,7 @@ class HumanEvalArguments:
"""
model_ckpt: Optional[str] = field(
- default="lvwerra/codeparrot",
- metadata={"help": "Model name or path of model to be evaluated."},
+ default="codeparrot/codeparrot", metadata={"help": "Model name or path of model to be evaluated."}
)
num_workers: Optional[int] = field(default=None, metadata={"help": "Number of workers used for code evaluation."})
num_tasks: Optional[int] = field(
@@ -112,7 +108,10 @@ class HumanEvalArguments:
device_int: Optional[int] = field(
default=-1,
metadata={
- "help": "Determine which device to run the `text-generation` Pipeline on. -1 is CPU and any zero or positive number corresponds to which GPU device id to run on."
+ "help": (
+ "Determine which device to run the `text-generation` Pipeline on. -1 is CPU and any zero or positive"
+ " number corresponds to which GPU device id to run on."
+ )
},
)
@@ -130,7 +129,7 @@ class PreprocessingArguments:
},
)
dataset_name: Optional[str] = field(
- default="codeparrot", metadata={"help": "Folder or name of dataset to process."}
+ default="transformersbook/codeparrot", metadata={"help": "Folder or name of dataset to process."}
)
output_dir: Optional[str] = field(
default="codeparrot-clean", metadata={"help": "Folder to save processed processed dataset."}
@@ -148,6 +147,22 @@ class PreprocessingArguments:
alpha_frac: Optional[float] = field(
default=0.25, metadata={"help": "Maximum fraction of non-alphanumeric characters, otherwise file is filtered."}
)
+ min_token_ratio: Optional[float] = field(
+ default=1.5, metadata={"help": "Minimum character token ratio for the file, otherwise file is filtered."}
+ )
+ filter_proba: Optional[float] = field(
+ default=0.7, metadata={"help": "Probability for filtering config, test and uncommon files."}
+ )
+ tokenizer: Optional[str] = field(
+ default="codeparrot/codeparrot",
+ metadata={"help": "Name or path to the tokenizer."},
+ )
+ near_deduplication: Optional[bool] = field(
+ default=False, metadata={"help": "If True, near-duplicate samples are removed."}
+ )
+ jaccard_threshold: Optional[float] = field(
+ default=0.85, metadata={"help": "Jaccard threshold for near-duplicate samples."}
+ )
@dataclass
@@ -157,14 +172,13 @@ class TokenizerTrainingArguments:
"""
base_tokenizer: Optional[str] = field(
- default="gpt2",
- metadata={"help": "Base tokenizer to build new tokenizer from."},
+ default="gpt2", metadata={"help": "Base tokenizer to build new tokenizer from."}
)
dataset_name: Optional[str] = field(
default="transformersbook/codeparrot-train", metadata={"help": "Dataset to train tokenizer on."}
)
text_column: Optional[str] = field(default="content", metadata={"help": "Column containing text data to process."})
- vocab_size: Optional[int] = field(default=200000, metadata={"help": "Number of examples to train tokenizer on."})
+    vocab_size: Optional[int] = field(default=200_000, metadata={"help": "Vocabulary size of the new tokenizer."})
n_examples: Optional[int] = field(
default=32768, metadata={"help": "Number of examples to train the tokenizer on."}
)
@@ -172,6 +186,24 @@ class TokenizerTrainingArguments:
push_to_hub: Optional[bool] = field(default=True, metadata={"help": "Push saved tokenizer to the hub."})
+@dataclass
+class PretokenizationArguments:
+ """
+ Configuration for data pretokenization.
+ """
+
+ tokenizer_dir: Optional[str] = field(
+ default="codeparrot/codeparrot", metadata={"help": "Name or path to the tokenizer."}
+ )
+ dataset_name: Optional[str] = field(
+ default="codeparrot/codeparrot-clean-train", metadata={"help": "Name or path to the dataset to pretokenize."}
+ )
+ tokenized_data_repo: Optional[str] = field(
+ default="tokenized-codeparrot-train", metadata={"help": "Repo name of the pretokenized data."}
+ )
+    num_workers: Optional[int] = field(default=None, metadata={"help": "Number of workers used for pretokenization."})
+
+
@dataclass
class InitializationArguments:
"""
@@ -179,11 +211,10 @@ class InitializationArguments:
"""
config_name: Optional[str] = field(
- default="gpt2-large",
- metadata={"help": "Configuration to use for model initialization."},
+ default="gpt2-large", metadata={"help": "Configuration to use for model initialization."}
)
tokenizer_name: Optional[str] = field(
- default="lvwerra/codeparrot", metadata={"help": "Tokenizer attached to model."}
+ default="codeparrot/codeparrot", metadata={"help": "Tokenizer attached to model."}
)
model_name: Optional[str] = field(default="codeparrot", metadata={"help": "Name of the created model."})
push_to_hub: Optional[bool] = field(default=True, metadata={"help": "Push saved tokenizer to the hub."})
diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py
index b00afac7508f..b2af8767a217 100644
--- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py
+++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py
@@ -7,14 +7,16 @@
import datasets
import torch
from datasets import load_dataset
+from torch.optim import AdamW
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
import transformers
from accelerate import Accelerator, DistributedType
from arguments import TrainingArguments
from huggingface_hub import Repository
-from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
+from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
class ConstantLengthDataset(IterableDataset):
@@ -25,21 +27,36 @@ class ConstantLengthDataset(IterableDataset):
dataset (dataset.Dataset): Dataset with text files.
infinite (bool): If True the iterator is reset after dataset reaches end else stops.
seq_length (int): Length of token sequences to return.
- num_of_sequences: Number of token sequences to keep in buffer.
- chars_per_token: Number of characters per token used to estimate number of tokens in text buffer.
+ num_of_sequences (int): Number of token sequences to keep in buffer.
+ chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
+ tokenized (bool): If true we use a pretokenized dataset.
"""
def __init__(
- self, tokenizer, dataset, infinite=False, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6
+ self,
+ tokenizer,
+ dataset,
+ infinite=False,
+ seq_length=1024,
+ num_of_sequences=1024,
+ chars_per_token=3.6,
+ tokenized=False,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.bos_token_id
self.dataset = dataset
self.seq_length = seq_length
- self.input_characters = seq_length * chars_per_token * num_of_sequences
self.epoch = 0
self.infinite = infinite
self.current_size = 0
+ self.tokenized = tokenized
+
+ if self.tokenized:
+ self.max_buffer_size = seq_length * num_of_sequences
+ self.content_field = "input_ids"
+ else:
+ self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
+ self.content_field = "content"
def __iter__(self):
iterator = iter(self.dataset)
@@ -47,10 +64,10 @@ def __iter__(self):
while more_examples:
buffer, buffer_len = [], 0
while True:
- if buffer_len >= self.input_characters:
+ if buffer_len >= self.max_buffer_size:
break
try:
- buffer.append(next(iterator)["content"])
+ buffer.append(next(iterator)[self.content_field])
buffer_len += len(buffer[-1])
except StopIteration:
if self.infinite:
@@ -60,7 +77,10 @@ def __iter__(self):
else:
more_examples = False
break
- tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
+ if self.tokenized:
+ tokenized_inputs = buffer
+ else:
+ tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
all_token_ids.extend(tokenized_input + [self.concat_token_id])
@@ -70,6 +90,9 @@ def __iter__(self):
self.current_size += 1
yield torch.tensor(input_ids)
+ def shuffle(self, buffer_size=1000):
+ return ShufflerIterDataPipe(self, buffer_size=buffer_size)
+
def setup_logging(args):
project_name = args.model_ckpt.split("/")[-1]
@@ -102,14 +125,19 @@ def create_dataloaders(args):
train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
- train_dataset = ConstantLengthDataset(tokenizer, train_data, infinite=True, seq_length=args.seq_length)
- valid_dataset = ConstantLengthDataset(tokenizer, valid_data, infinite=False, seq_length=args.seq_length)
- train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
+ train_dataset = ConstantLengthDataset(
+ tokenizer, train_data, infinite=True, seq_length=args.seq_length, tokenized=args.tokenized
+ )
+ valid_dataset = ConstantLengthDataset(
+ tokenizer, valid_data, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized
+ )
+ train_dataset = train_dataset.shuffle(buffer_size=args.shuffle_buffer)
+ train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
return train_dataloader, eval_dataloader
-def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
+def get_grouped_params(model, args, no_decay=["bias", "ln_1.weight", "ln_2.weight", "ln_f.weight"]):
params_with_wd, params_without_wd = [], []
for n, p in model.named_parameters():
if any(nd in n for nd in no_decay):
@@ -162,14 +190,14 @@ def evaluate(args):
return loss.item(), perplexity.item()
-# Accelerator
-accelerator = Accelerator(log_with=["wandb", "tensorboard"])
-acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
-
# Settings
parser = HfArgumentParser(TrainingArguments)
args = parser.parse_args()
+# Accelerator
+accelerator = Accelerator(log_with=["wandb", "tensorboard"], logging_dir=f"{args.save_dir}/log")
+acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
+
args = Namespace(**vars(args), **acc_state)
samples_per_step = accelerator.state.num_processes * args.train_batch_size
set_seed(args.seed)
@@ -234,13 +262,14 @@ def get_lr():
model.train()
completed_steps = 0
t_start = time.time()
+loss_tracking = 0
for step, batch in enumerate(train_dataloader, start=1):
if args.resume_from_checkpoint and step < resume_step:
continue # we need to skip steps until we reach the resumed step
loss = model(batch, labels=batch, use_cache=False).loss
- log_metrics(
- step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
- )
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ loss_tracking += avg_loss.item() / args.gradient_accumulation_steps
+ log_metrics(step, {"samples": step * samples_per_step, "loss_per_step/train": loss.item()})
loss = loss / args.gradient_accumulation_steps
if step % args.gradient_accumulation_steps != 0:
# Prevent backward from doing gradient all_reduce in every step
@@ -250,16 +279,27 @@ def get_lr():
else:
accelerator.backward(loss)
else:
+ lr = get_lr()
accelerator.backward(loss)
accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
- completed_steps += 1
elapsed_time = time.time() - t_start
tflops = compute_tflops(elapsed_time, accelerator, args)
- log_metrics(step, {"steps": completed_steps, "tflops": tflops, "time_per_iteration": elapsed_time})
+ log_metrics(
+ step,
+ {
+ "steps": completed_steps,
+ "loss/train": loss_tracking,
+ "lr": lr,
+ "tflops": tflops,
+ "time_per_iteration": elapsed_time,
+ },
+ )
t_start = time.time()
+ loss_tracking = 0
+ completed_steps += 1
if step % args.save_checkpoint_steps == 0:
logger.info("Evaluating and saving model checkpoint")
eval_loss, perplexity = evaluate(args)
diff --git a/examples/research_projects/codeparrot/scripts/human_eval.py b/examples/research_projects/codeparrot/scripts/human_eval.py
index 1eb5555cd79c..d0614134ad47 100644
--- a/examples/research_projects/codeparrot/scripts/human_eval.py
+++ b/examples/research_projects/codeparrot/scripts/human_eval.py
@@ -186,7 +186,8 @@ def main():
_ = code_eval_metric.compute(references=[""], predictions=[[""]])
except ValueError as exception:
print(
- 'Code evaluation not enabled. Read the warning below carefully and then use `--HF_ALLOW_CODE_EVAL="1"` flag to enable code evaluation.'
+ 'Code evaluation not enabled. Read the warning below carefully and then use `--HF_ALLOW_CODE_EVAL="1"`'
+ " flag to enable code evaluation."
)
raise exception
diff --git a/examples/research_projects/codeparrot/scripts/initialize_model.py b/examples/research_projects/codeparrot/scripts/initialize_model.py
index 8654ccc90622..9d066b190873 100644
--- a/examples/research_projects/codeparrot/scripts/initialize_model.py
+++ b/examples/research_projects/codeparrot/scripts/initialize_model.py
@@ -10,13 +10,17 @@
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
# Config: "scale_attn_by_layer_idx" and "reorder_and_upcast_attn" are Mistral stability tweaks
-config_kwargs = {"vocab_size": len(tokenizer), "scale_attn_by_layer_idx": True, "reorder_and_upcast_attn": True}
+config_kwargs = {
+ "vocab_size": len(tokenizer),
+ "scale_attn_by_inverse_layer_idx": True,
+ "reorder_and_upcast_attn": True,
+}
# Load model config (GPT-2 large in this case)
config = AutoConfig.from_pretrained(args.config_name, **config_kwargs)
# Initialize new model with config
-model = AutoModelForCausalLM(config)
+model = AutoModelForCausalLM.from_config(config)
# Save model to the hub
model.save_pretrained(args.model_name, push_to_hub=args.push_to_hub)
diff --git a/examples/research_projects/codeparrot/scripts/minhash_deduplication.py b/examples/research_projects/codeparrot/scripts/minhash_deduplication.py
new file mode 100644
index 000000000000..cd72dcb70c9e
--- /dev/null
+++ b/examples/research_projects/codeparrot/scripts/minhash_deduplication.py
@@ -0,0 +1,270 @@
+import json
+import multiprocessing as mp
+import re
+from collections import defaultdict
+from functools import partial
+from typing import Dict, List, Optional, Set, Tuple, Type
+
+from datasets import Dataset
+from tqdm import tqdm
+
+from datasketch import MinHash, MinHashLSH
+from dpu_utils.utils.iterators import ThreadedIterator
+
+
+NON_ALPHA = re.compile("[^A-Za-z_0-9]")
+# parameters used in DuplicationIndex
+MIN_NUM_TOKENS = 10
+NUM_PERM = 256
+
+
+def get_min_hash(tokens: List[str]) -> Optional[MinHash]:
+ """Compute the MinHash of a code snippet."""
+ if len(tokens) < MIN_NUM_TOKENS:
+ return None
+ min_hash = MinHash(num_perm=NUM_PERM)
+ for token in set(tokens):
+ min_hash.update(token.encode())
+ return min_hash
+
+
+def get_tokens(code: str) -> Set[str]:
+ """Tokenize a code snippet."""
+ return set([t for t in NON_ALPHA.split(code) if len(t.strip()) > 0])
+
+
+class DuplicationIndex:
+ def __init__(
+ self,
+ *,
+ duplication_jaccard_threshold: float = 0.85,
+ ):
+ self._duplication_jaccard_threshold = duplication_jaccard_threshold
+ self._num_perm = NUM_PERM
+ self._index = MinHashLSH(threshold=self._duplication_jaccard_threshold, num_perm=self._num_perm)
+
+ self._duplicate_clusters = defaultdict(set)
+
+ def add(self, code_key: Tuple, min_hash: MinHash) -> None:
+        """Add a key to _index (MinHashLSH).
+        The min_hash is used to query the closest matches based on the jaccard_threshold.
+        The new key is either added to an existing cluster of one close match,
+        or a new cluster is created. The clusters created in this way depend on the order of the add calls.
+
+        Args:
+            code_key (Tuple of (index, repo_name, path)):
+                Theoretically any hashable key. Here we use a tuple to retrieve the information later.
+ min_hash: MinHash of the code_key.
+ """
+ close_duplicates = self._index.query(min_hash)
+ if code_key in self._index.keys:
+ print(f"Duplicate key {code_key}")
+ return
+
+ self._index.insert(code_key, min_hash)
+ if len(close_duplicates) > 0:
+
+ for base_duplicate in close_duplicates:
+ if base_duplicate in self._duplicate_clusters:
+ self._duplicate_clusters[base_duplicate].add(code_key)
+ break
+ else:
+ self._duplicate_clusters[close_duplicates[0]].add(code_key)
+
+ def get_duplicate_clusters(self) -> List[List[Dict]]:
+ """Export the duplicate clusters.
+ For each cluster, the first element is the base element of the cluster.
+        The base element has an estimated Jaccard similarity higher than the threshold with all the other elements.
+
+ Returns:
+ duplicate_clusters (List[List[Dict]]):
+ List of duplicate clusters.
+ """
+ duplicate_clusters = []
+ for base, duplicates in self._duplicate_clusters.items():
+ cluster = [base] + list(duplicates)
+ # reformat the cluster to be a list of dict
+ cluster = [{"base_index": el[0], "repo_name": el[1], "path": el[2]} for el in cluster]
+ duplicate_clusters.append(cluster)
+ return duplicate_clusters
+
+ def save(self, filepath) -> None:
+ duplicate_clusters = self.get_duplicate_clusters()
+ with open(filepath, "w") as f:
+ json.dump(duplicate_clusters, f)
+
+
+def _compute_min_hash(element):
+ index, data = element
+ min_hash = get_min_hash([t for t in NON_ALPHA.split(data["content"]) if len(t.strip()) > 0])
+ if min_hash is not None:
+ return (index, data["repo_name"], data["path"]), min_hash
+
+
+def minhash_iter(dataset_iterator: Type[Dataset]):
+ with mp.Pool() as pool:
+ for data in pool.imap_unordered(
+ _compute_min_hash,
+ ThreadedIterator(dataset_iterator, max_queue_size=10000),
+ chunksize=100,
+ ):
+ if data is not None:
+ yield data
+
+
+def make_duplicate_clusters(dataset_iterator: Type[Dataset], jaccard_threshold: float):
+ """Find duplicate clusters in the dataset in two steps:
+    1. Compute MinHash for each code snippet. MinHash is a tool for fast Jaccard similarity estimation.
+    This step is computed using an asynchronous multiprocessing pool, minhash_iter.
+    2. Find duplicate clusters. The computed MinHash is added sequentially to the DuplicationIndex.
+    This step cannot be parallelized. So using asynchronous threads in the previous step helps to speed up the process.
+ """
+ di = DuplicationIndex(duplication_jaccard_threshold=jaccard_threshold)
+
+ for filename, min_hash in tqdm(ThreadedIterator(minhash_iter(enumerate(dataset_iterator)), max_queue_size=100)):
+ di.add(filename, min_hash)
+
+ # Returns a List[Cluster] where Cluster is List[str] with the filenames.
+ return di.get_duplicate_clusters()
+
+
+def jaccard_similarity(code1: str, code2: str) -> float:
+ """Compute the Jaccard similarity of two code snippets."""
+ tokens1 = get_tokens(code1)
+ tokens2 = get_tokens(code2)
+ return len(tokens1 & tokens2) / len(tokens1 | tokens2)
+
+
+_shared_dataset = None
+
+
+def _find_cluster_extremes_shared(cluster, jaccard_threshold):
+    """Find a reduced cluster such that each code in the original cluster is similar to at least one code in the reduced cluster.
+ Two codes are similar if their Jaccard similarity is above the threshold.
+
+ Args:
+ cluster (List[dict]):
+ cluster is a list of dict, each dict contains the following keys:
+ - base_index
+ - repo_name
+ - path
+ This is a typical output of DuplicationIndex.get_duplicate_clusters()
+ jaccard_threshold (float):
+ threshold for Jaccard similarity.
+ Two codes are similar if their Jaccard similarity is above the threshold.
+
+ Returns:
+ extremes (List[dict]):
+ A reduced representation of the cluster. The field copies is added to each dict.
+            The copies field indicates the number of similar codes in the cluster for an extreme.
+ """
+ extremes = []
+ for element1 in cluster:
+ code1 = _shared_dataset[element1["base_index"]]["content"]
+ for element2 in extremes:
+ code2 = _shared_dataset[element2["base_index"]]["content"]
+ if jaccard_similarity(code1, code2) >= jaccard_threshold:
+ element2["copies"] += 1
+ break
+ else:
+ element1["copies"] = 1
+ extremes.append(element1)
+ return extremes
+
+
+def find_extremes(cluster_list, dataset, jaccard_threshold):
+ """Call the _find_cluster_extremes_shared function in a parallel fashion.
+
+ Args:
+ cluster_list (List[List[Dict]]):
+ each cluster is a list of dicts with the key base_index,
+ referring to the index of the base code in the dataset.
+ dataset (Type[Dataset]):
+ dataset is used to access the content of the code snippets,
+ using the base_index from the cluster_list.
+            dataset is shared between all the processes using a global variable (any other way to share the dataset?),
+            otherwise the multiprocessing is not sped up.
+        jaccard_threshold (float):
+            the threshold for the Jaccard similarity. The default value is 0.85.
+
+ Returns:
+ extremes_list (List[Dict]):
+ Each cluster is reduced to extremes.
+ See _find_cluster_extremes_shared for the definition of extremes.
+ """
+ global _shared_dataset
+ _shared_dataset = dataset
+ extremes_list = []
+ f = partial(_find_cluster_extremes_shared, jaccard_threshold=jaccard_threshold)
+ with mp.Pool() as pool:
+ for extremes in tqdm(
+ pool.imap_unordered(
+ f,
+ cluster_list,
+ ),
+ total=len(cluster_list),
+ ):
+ extremes_list.append(extremes)
+ return extremes_list
+
+
+def deduplicate_dataset(
+ dataset: Type[Dataset], jaccard_threshold: float = 0.85
+) -> Tuple[Type[Dataset], List[List[Dict]]]:
+ """Deduplicate the dataset using minhash and jaccard similarity.
+    This function first generates duplicate clusters, then each cluster
+ is reduced to the extremes that are similar to the other elements in the cluster.
+ Codes are called similar if their Jaccard similarity is greater than jaccard_threshold (0.85 default).
+
+ Args:
+ dataset (Type[Dataset]):
+ The dataset to deduplicate.
+ jaccard_threshold (float, default=0.85):
+ jaccard threshold to determine if two codes are similar
+
+ Returns:
+ ds_dedup (Type[Dataset]):
+ The deduplicated dataset.
+ duplicate_clusters (List[List[Dict]]):
+ The list of duplicate clusters.
+ Each cluster is a list of dicts with the following keys:
+ - base_index : int
+ The index of the code in the original dataset.
+ - repo_name : str
+ - path : str
+ - copies : int
+ The number of copies of the code in the cluster. (find_cluster_extremes)
+ - is_extreme : bool
+ Whether the code is an extreme in the cluster.
+ All the codes in the cluster are removed from the dataset except the extremes.
+
+ Example:
+ >>> from datasets import load_dataset
+ >>> from minhash_deduplication import deduplicate_dataset
+ >>> ds = load_dataset("lvwerra/codeparrot-clean", split="train")
+ >>> ds_dedup, duplicate_clusters = deduplicate_dataset(ds, jaccard_threshold=0.85)
+ """
+ duplicate_clusters = make_duplicate_clusters(dataset, jaccard_threshold)
+ duplicate_indices = set(x["base_index"] for cluster in duplicate_clusters for x in cluster)
+ extreme_dict = {}
+ extremes_clusters = find_extremes(duplicate_clusters, dataset, jaccard_threshold)
+ for extremes in extremes_clusters:
+ for element in extremes:
+ extreme_dict[element["base_index"]] = element
+ remove_indices = duplicate_indices - set(extreme_dict.keys())
+ ds_filter = dataset.filter(lambda x, idx: idx not in remove_indices, with_indices=True)
+
+ # update duplicate_clusters
+ for cluster in duplicate_clusters:
+ for element in cluster:
+ element["is_extreme"] = element["base_index"] in extreme_dict
+ if element["is_extreme"]:
+ element["copies"] = extreme_dict[element["base_index"]]["copies"]
+
+ print(f"Original dataset size: {len(dataset)}")
+ print(f"Number of duplicate clusters: {len(duplicate_clusters)}")
+ print(f"Files in duplicate cluster: {len(duplicate_indices)}")
+ print(f"Unique files in duplicate cluster: {len(extreme_dict)}")
+ print(f"Filtered dataset size: {len(ds_filter)}")
+
+ return ds_filter, duplicate_clusters
diff --git a/examples/research_projects/codeparrot/scripts/preprocessing.py b/examples/research_projects/codeparrot/scripts/preprocessing.py
index 4e09379a943f..6236a8aad86a 100644
--- a/examples/research_projects/codeparrot/scripts/preprocessing.py
+++ b/examples/research_projects/codeparrot/scripts/preprocessing.py
@@ -1,20 +1,27 @@
import gzip
import hashlib
+import json
import multiprocessing
import os
+import re
import shutil
import time
+from pathlib import Path
import numpy as np
from datasets import load_dataset
from arguments import PreprocessingArguments
-from transformers import HfArgumentParser
+from minhash_deduplication import deduplicate_dataset
+from transformers import AutoTokenizer, HfArgumentParser
+
+
+PATTERN = re.compile(r"\s+")
def get_hash(example):
"""Get hash of content field."""
- return {"hash": hashlib.md5(example["content"].strip().encode("utf-8")).hexdigest()}
+ return {"hash": hashlib.md5(re.sub(PATTERN, "", example["content"]).encode("utf-8")).hexdigest()}
def line_stats(example):
@@ -50,18 +57,77 @@ def is_autogenerated(example, scan_width=5):
return {"autogenerated": False}
+def is_config_or_test(example, scan_width=5, coeff=0.05):
+    """Check if file is a configuration file or a unit test by:
+    1- looking for keywords in the first few lines of the file.
+    2- counting the number of occurrences of the words 'config' and 'test' with respect to the number of lines.
+ """
+
+ keywords = ["unit tests", "test file", "configuration file"]
+ lines = example["content"].splitlines()
+ count_config = 0
+ count_test = 0
+ # first test
+ for _, line in zip(range(scan_width), lines):
+ for keyword in keywords:
+ if keyword in line.lower():
+ return {"config_or_test": True}
+ # second test
+ nlines = example["content"].count("\n")
+ threshold = int(coeff * nlines)
+ for line in lines:
+ count_config += line.lower().count("config")
+ count_test += line.lower().count("test")
+ if count_config > threshold or count_test > threshold:
+ return {"config_or_test": True}
+ return {"config_or_test": False}
+
+
+def has_no_keywords(example):
+    """Check if a python file has none of the keywords for: function, class, for loop, while loop."""
+ keywords = ["def ", "class ", "for ", "while "]
+ lines = example["content"].splitlines()
+ for line in lines:
+ for keyword in keywords:
+ if keyword in line.lower():
+ return {"has_no_keywords": False}
+ return {"has_no_keywords": True}
+
+
+def has_few_assignments(example, minimum=4):
+    """Check if file uses the assignment symbol '=' at most `minimum` times."""
+ lines = example["content"].splitlines()
+ counter = 0
+ for line in lines:
+ counter += line.lower().count("=")
+ if counter > minimum:
+ return {"has_few_assignments": False}
+ return {"has_few_assignments": True}
+
+
+def char_token_ratio(example):
+ """Compute character/token ratio of the file with tokenizer."""
+ input_ids = tokenizer(example["content"], truncation=False)["input_ids"]
+ ratio = len(example["content"]) / len(input_ids)
+ return {"ratio": ratio}
+
+
def preprocess(example):
"""Chain all preprocessing steps into one function to not fill cache."""
results = dict()
results.update(get_hash(example))
results.update(line_stats(example))
results.update(alpha_stats(example))
+ results.update(char_token_ratio(example))
results.update(is_autogenerated(example))
+ results.update(is_config_or_test(example))
+ results.update(has_no_keywords(example))
+ results.update(has_few_assignments(example))
return results
def filter(example, uniques, args):
- """Filter dataset with heuristics."""
+ """Filter dataset with heuristics. Config, test and has_no_keywords files are removed with a given probability."""
if not check_uniques(example, uniques):
return False
elif example["autogenerated"]:
@@ -72,6 +138,14 @@ def filter(example, uniques, args):
return False
elif example["alpha_frac"] < args.alpha_frac:
return False
+ elif example["ratio"] < args.min_token_ratio:
+ return False
+ elif example["config_or_test"] and np.random.rand() <= args.filter_proba:
+ return False
+ elif example["has_no_keywords"] and np.random.rand() <= args.filter_proba:
+ return False
+ elif example["has_few_assignments"]:
+ return False
else:
return True
@@ -79,7 +153,7 @@ def filter(example, uniques, args):
def compress_file(file_path):
"""Compress a file with g-zip."""
with open(file_path, "rb") as f_in:
- with gzip.open(file_path + ".gz", "wb", compresslevel=6) as f_out:
+ with gzip.open(str(file_path) + ".gz", "wb", compresslevel=6) as f_out:
shutil.copyfileobj(f_in, f_out)
os.unlink(file_path)
@@ -89,6 +163,7 @@ def compress_file(file_path):
args = parser.parse_args()
if args.num_workers is None:
args.num_workers = multiprocessing.cpu_count()
+tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
# Load dataset
t_start = time.time()
@@ -111,12 +186,29 @@ def compress_file(file_path):
print(f"Time to filter dataset: {time.time()-t_start:.2f}")
print(f"Size of filtered dataset: {len(ds_filter)}")
+# Deduplicate with minhash and jaccard similarity
+if args.near_deduplication:
+ t_start = time.time()
+ ds_filter, duplicate_clusters = deduplicate_dataset(ds_filter, args.jaccard_threshold)
+ print(f"Time to deduplicate dataset: {time.time()-t_start:.2f}")
+    print(f"Size of deduplicated dataset: {len(ds_filter)}")
+
# Save data in batches of samples_per_file
-if not os.path.exists(args.output_dir):
- os.makedirs(args.output_dir)
+output_dir = Path(args.output_dir)
+output_dir.mkdir(exist_ok=True)
+
+# save duplicate_clusters in the output_dir as artifacts
+# not sure it is the right place to save it
+if args.near_deduplication:
+ with open(output_dir / "duplicate_clusters.json", "w") as f:
+ json.dump(duplicate_clusters, f)
+
+data_dir = output_dir / "data"
+data_dir.mkdir(exist_ok=True)
+
t_start = time.time()
for file_number, index in enumerate(range(0, len(ds_filter), args.samples_per_file)):
- file_path = f"{args.output_dir}/file-{file_number+1:012}.json"
+ file_path = str(data_dir / f"file-{file_number+1:012}.json")
end_index = min(len(ds_filter), index + args.samples_per_file)
ds_filter.select(list(range(index, end_index))).to_json(file_path)
compress_file(file_path)
diff --git a/examples/research_projects/codeparrot/scripts/pretokenizing.py b/examples/research_projects/codeparrot/scripts/pretokenizing.py
new file mode 100644
index 000000000000..9ebe1e577dde
--- /dev/null
+++ b/examples/research_projects/codeparrot/scripts/pretokenizing.py
@@ -0,0 +1,49 @@
+import multiprocessing
+import time
+
+from datasets import load_dataset
+
+from arguments import PretokenizationArguments
+from transformers import AutoTokenizer, HfArgumentParser
+
+
+def tokenize(example):
+ output = dict()
+ output["input_ids"] = tokenizer(example["content"], truncation=False)["input_ids"]
+ output["ratio_char_token"] = len(example["content"]) / len(output["input_ids"])
+ return output
+
+
+parser = HfArgumentParser(PretokenizationArguments)
+args = parser.parse_args()
+if args.num_workers is None:
+ args.num_workers = multiprocessing.cpu_count()
+tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
+
+t_start = time.time()
+ds = load_dataset(args.dataset_name, split="train")
+print(f"Dataset loaded in {time.time()-t_start:.2f}s")
+
+t_start = time.time()
+ds = ds.map(
+ tokenize,
+ num_proc=args.num_workers,
+ remove_columns=[
+ "repo_name",
+ "path",
+ "copies",
+ "size",
+ "content",
+ "license",
+ "hash",
+ "line_mean",
+ "line_max",
+ "alpha_frac",
+ "autogenerated",
+ ],
+)
+print(f"Dataset tokenized in {time.time()-t_start:.2f}s")
+
+t_start = time.time()
+ds.push_to_hub(args.tokenized_data_repo)
+print(f"Data pushed to the hub in {time.time()-t_start:.2f}s")
diff --git a/examples/research_projects/codeparrot/scripts/tests/__init__.py b/examples/research_projects/codeparrot/scripts/tests/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/examples/research_projects/codeparrot/scripts/tests/test_deduplicate.py b/examples/research_projects/codeparrot/scripts/tests/test_deduplicate.py
new file mode 100644
index 000000000000..e44382713557
--- /dev/null
+++ b/examples/research_projects/codeparrot/scripts/tests/test_deduplicate.py
@@ -0,0 +1,30 @@
+from unittest import TestCase
+
+from datasets import Dataset
+
+from minhash_deduplication import deduplicate_dataset, make_duplicate_clusters
+
+
+def get_dataset():
+ data_dict = {
+ "repo_name": ["test_repo1", "test_repo2", "test_repo3"],
+ "path": ["test_1.py", "test_2.py", "unit_test.py"],
+ "content": ["a " * 20, "a " * 30, "b " * 7],
+ }
+ dataset = Dataset.from_dict(data_dict)
+ return dataset
+
+
+class MakeDuplicateClustersTest(TestCase):
+ def test_make_duplicate_clusters(self):
+ ds = get_dataset()
+ duplicate_clusters = make_duplicate_clusters(ds, 0.85)
+ self.assertEqual(len(duplicate_clusters[0]), 2)
+
+ def test_deduplicate_dataset(self):
+ ds = get_dataset()
+ ds_filter, duplicate_clusters = deduplicate_dataset(ds)
+ self.assertEqual(len(ds_filter), 2)
+ print(duplicate_clusters)
+ self.assertEqual(duplicate_clusters[0][0]["copies"], 2)
+ self.assertEqual(duplicate_clusters[0][0]["is_extreme"], True)
diff --git a/examples/research_projects/decision_transformer/requirements.txt b/examples/research_projects/decision_transformer/requirements.txt
index 4924f4b513d2..bf3dd4f1777f 100644
--- a/examples/research_projects/decision_transformer/requirements.txt
+++ b/examples/research_projects/decision_transformer/requirements.txt
@@ -33,7 +33,7 @@ cmaes==0.8.2
cmd2==2.4.0
codecarbon==1.2.0
colorlog==6.6.0
-cookiecutter==1.7.2
+cookiecutter==2.1.1
cryptography==36.0.2
csvw==2.0.0
cycler==0.11.0
@@ -205,7 +205,7 @@ tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorboardX==2.5
-tensorflow==2.8.0
+tensorflow==2.8.1
tensorflow-io-gcs-filesystem==0.24.0
termcolor==1.1.0
text-unidecode==1.3
diff --git a/examples/research_projects/deebert/run_glue_deebert.py b/examples/research_projects/deebert/run_glue_deebert.py
index 5bfc2f8816dc..f86390375ff7 100644
--- a/examples/research_projects/deebert/run_glue_deebert.py
+++ b/examples/research_projects/deebert/run_glue_deebert.py
@@ -459,8 +459,10 @@ def main():
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -529,8 +531,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
diff --git a/examples/research_projects/distillation/grouped_batch_sampler.py b/examples/research_projects/distillation/grouped_batch_sampler.py
index 6c2d9b974886..83addc371f2e 100644
--- a/examples/research_projects/distillation/grouped_batch_sampler.py
+++ b/examples/research_projects/distillation/grouped_batch_sampler.py
@@ -60,7 +60,7 @@ class GroupedBatchSampler(BatchSampler):
def __init__(self, sampler, group_ids, batch_size):
if not isinstance(sampler, Sampler):
raise ValueError(
- "sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
+ "sampler should be an instance of torch.utils.data.Sampler, but got sampler={}".format(sampler)
)
self.sampler = sampler
self.group_ids = group_ids
diff --git a/examples/research_projects/distillation/run_squad_w_distillation.py b/examples/research_projects/distillation/run_squad_w_distillation.py
index ea1f2f46a969..3acfd4686406 100644
--- a/examples/research_projects/distillation/run_squad_w_distillation.py
+++ b/examples/research_projects/distillation/run_squad_w_distillation.py
@@ -518,7 +518,10 @@ def main():
"--teacher_type",
default=None,
type=str,
- help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
+ help=(
+ "Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for"
+ " distillation."
+ ),
)
parser.add_argument(
"--teacher_name_or_path",
@@ -590,8 +593,10 @@ def main():
"--max_seq_length",
default=384,
type=int,
- help="The maximum total input sequence length after WordPiece tokenization. Sequences "
- "longer than this will be truncated, and sequences shorter than this will be padded.",
+ help=(
+ "The maximum total input sequence length after WordPiece tokenization. Sequences "
+ "longer than this will be truncated, and sequences shorter than this will be padded."
+ ),
)
parser.add_argument(
"--doc_stride",
@@ -603,8 +608,10 @@ def main():
"--max_query_length",
default=64,
type=int,
- help="The maximum number of tokens for the question. Questions longer than this will "
- "be truncated to this length.",
+ help=(
+ "The maximum number of tokens for the question. Questions longer than this will "
+ "be truncated to this length."
+ ),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -649,14 +656,18 @@ def main():
"--max_answer_length",
default=30,
type=int,
- help="The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another.",
+ help=(
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ ),
)
parser.add_argument(
"--verbose_logging",
action="store_true",
- help="If true, all of the warnings related to data processing will be printed. "
- "A number of warnings are expected for a normal SQuAD evaluation.",
+ help=(
+ "If true, all of the warnings related to data processing will be printed. "
+ "A number of warnings are expected for a normal SQuAD evaluation."
+ ),
)
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
@@ -685,8 +696,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
diff --git a/examples/research_projects/distillation/scripts/extract.py b/examples/research_projects/distillation/scripts/extract.py
index d7a99b1d89d0..f60f243dece6 100644
--- a/examples/research_projects/distillation/scripts/extract.py
+++ b/examples/research_projects/distillation/scripts/extract.py
@@ -25,7 +25,10 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser(
- description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
+ description=(
+ "Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned"
+ " Distillation"
+ )
)
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
parser.add_argument("--model_name", default="roberta-large", type=str)
diff --git a/examples/research_projects/distillation/scripts/extract_distilbert.py b/examples/research_projects/distillation/scripts/extract_distilbert.py
index e125f36187cd..a58105f999e8 100644
--- a/examples/research_projects/distillation/scripts/extract_distilbert.py
+++ b/examples/research_projects/distillation/scripts/extract_distilbert.py
@@ -25,7 +25,10 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser(
- description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
+ description=(
+ "Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned"
+ " Distillation"
+ )
)
parser.add_argument("--model_type", default="bert", choices=["bert"])
parser.add_argument("--model_name", default="bert-base-uncased", type=str)
diff --git a/examples/research_projects/distillation/scripts/token_counts.py b/examples/research_projects/distillation/scripts/token_counts.py
index aa223fda7035..736b564ee76e 100644
--- a/examples/research_projects/distillation/scripts/token_counts.py
+++ b/examples/research_projects/distillation/scripts/token_counts.py
@@ -43,7 +43,7 @@
with open(args.data_file, "rb") as fp:
data = pickle.load(fp)
- logger.info("Counting occurences for MLM.")
+ logger.info("Counting occurrences for MLM.")
counter = Counter()
for tk_ids in data:
counter.update(tk_ids)
diff --git a/examples/research_projects/distillation/train.py b/examples/research_projects/distillation/train.py
index 6385c885a96e..cc2362888e47 100644
--- a/examples/research_projects/distillation/train.py
+++ b/examples/research_projects/distillation/train.py
@@ -207,8 +207,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.")
parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
@@ -226,8 +228,8 @@ def main():
if os.path.exists(args.dump_path):
if not args.force:
raise ValueError(
- f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite it"
- "Use `--force` if you want to overwrite it"
+ f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite"
+ " itUse `--force` if you want to overwrite it"
)
else:
shutil.rmtree(args.dump_path)
diff --git a/examples/research_projects/fsner/src/fsner/tokenizer_utils.py b/examples/research_projects/fsner/src/fsner/tokenizer_utils.py
index 6e4027a9891d..bc5f6650ccd9 100644
--- a/examples/research_projects/fsner/src/fsner/tokenizer_utils.py
+++ b/examples/research_projects/fsner/src/fsner/tokenizer_utils.py
@@ -48,7 +48,8 @@ def tokenize(self, x):
else:
raise Exception(
- "Type of parameter x was not recognized! Only `list of strings` for query or `list of lists of strings` for supports are supported."
+ "Type of parameter x was not recognized! Only `list of strings` for query or `list of lists of"
+ " strings` for supports are supported."
)
return d
diff --git a/examples/research_projects/information-gain-filtration/README.md b/examples/research_projects/information-gain-filtration/README.md
new file mode 100644
index 000000000000..bf95cb8ea814
--- /dev/null
+++ b/examples/research_projects/information-gain-filtration/README.md
@@ -0,0 +1,100 @@
+
+# Information Gain Filtration (IGF)
+
+Authors @Tuko @mraunak
+
+This folder contains code showing how to implement IGF for fine-tuning GPT-2.
+
+## What is IGF?
+
+Here we present a general fine-tuning method that we call Information Gain Filtration, aimed at improving the overall training efficiency and final
+performance of language model fine-tuning (see the paper below). It is an alternative fine-tuning method that trains
+a secondary model (e.g., a simple convolutional network) to predict the amount of information
+gained over a given pre-trained model. The secondary model is lightweight and trained to
+predict the Information Gain measure. Information Gain is defined as the change in a loss
+function for a model before and after an SGD update with a sample (Equation X in the paper).
+A small subset of the training set, named the "objective" set, is used to measure information
+gain on the pre-trained model, and consequently to train the secondary model. After
+training, the secondary model is used to filter samples during the fine-tuning process:
+a high information gain value suggests a sample is informative, whereas a low value
+suggests a non-informative sample that should be filtered out. A thresholding
+strategy is therefore defined to select informative samples. With such a strategy, samples are filtered,
+and once enough samples are selected to form a mini-batch, a usual fine-tuning/optimization
+step is applied. The filtration process is repeated until fine-tuning is over.
+
+Paper [Selecting Informative Contexts Improves Language Model Finetuning](https://arxiv.org/abs/2005.00175)
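+
+For illustration, the thresholded filtering loop described above can be sketched as follows. This is a minimal sketch, not the exact code in this folder; `model`, `optimizer`, `secondary_learner`, `candidate_contexts`, `threshold`, and `batch_size` are assumed names.
+
+```python
+import torch
+
+# Assumed to exist: a GPT-2 `model`, an `optimizer`, a trained `secondary_learner`,
+# an iterable `candidate_contexts` of token-id tensors, a `threshold`, and a `batch_size`.
+batch = []
+for context in candidate_contexts:
+    predicted_ig = secondary_learner(context.unsqueeze(0)).item()  # predicted IG(X)
+    if predicted_ig < threshold:
+        continue  # filter out non-informative samples
+    batch.append(context)
+    if len(batch) == batch_size:
+        # usual fine-tuning step on the filtered mini-batch
+        inputs = torch.stack(batch)
+        loss = model(inputs, labels=inputs)[0]
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+        batch = []
+```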
+
+## Results
+
+Several experiments were conducted to show the robustness of the IGF method versus the
+standard fine-tuning process. For example, we achieve a median perplexity of 54.0 on the
+Books dataset compared to 57.3 for standard fine-tuning on GPT-2 Small. The code was
+implemented using the Transformers library and PyTorch. While the method may seem more
+expensive, we saw enough evidence that it may lead to a performance benefit in the final models.
+
+
+![Comparing IGF to Standard Fine-tuning](result_igf.png)
+Figure 1: Comparing IGF to Standard Fine-tuning:
+IGF with constant (p < 10⁻³, t-test) and shifting (p < 10⁻⁶, t-test) thresholding significantly outperform standard fine-tuning. The left-hand figure shows
+test-set perplexity after each fine-tuning batch, averaged over 50 runs (error bars denote ± one standard error). The right-hand figure shows the perplexity of each
+method after 60 batches. IGF with shifting thresholding (red) clearly improves over standard batched fine-tuning with Adam.
+
+## How to use this project?
+
+To fine-tune a transformer model with IGF on a language modeling task, use the following script:
+
+- `model_name_or_path`: Path to pretrained model or model identifier from huggingface.co/models
+- `data_file`: A jbl file containing tokenized data which can be split as objective dataset,
+ train_dataset and test_dataset
+- `igf_data_file`: A jbl file containing the context and information gain pairs to train secondary learner.
+- `context_len`: The maximum total input sequence length after tokenization. Sequences longer
+ than this will be truncated, sequences shorter will be padded.
+- `size_objective_set`: Number of articles that are long enough to be used as our objective set
+- `min_len`: The minimum length of the article to be used as objective set
+- `trim`: Truncate the example if it exceeds context length
+- `eval_freq`: Secondary model evaluation can be triggered at eval_freq
+- `max_steps`: To calculate training epochs
+- `number`: The number of examples split to be used as objective_set/test_data
+- `secondary_learner_batch_size`: The batch size of training data for secondary learner
+- `secondary_learner_max_epochs`: The number of epochs to train secondary learner
+- `recopy_model`: Reset the model to the original pretrained GPT-2 weights after each iteration
+- `eval_interval`: Decay the selectivity of our secondary learner filter from
+  1 standard deviation above average to 1 below average after `eval_interval` (default 10) batches
+
+
+```bash
+python run_clm_igf.py \
+--model_name_or_path "gpt2" \
+--data_file="data/tokenized_stories_train_wikitext103" \
+--igf_data_file="data/IGF_values" \
+--context_len 32 \
+--size_objective_set 100 \
+--min_len 1026 \
+--trim True \
+--eval_freq 100 \
+--max_steps 1000 \
+--secondary_learner_batch_size 128 \
+--secondary_learner_max_epochs 15 \
+--number 100 \
+--recopy_model \
+--eval_interval 10
+```
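+
+Once `run_clm_igf.py` has produced a trained secondary learner (e.g. `igf_model.pt`), it can be reloaded and reused for IGF fine-tuning with the helpers from `igf/igf.py`. The snippet below is a minimal sketch based on the functions in this folder; it assumes it is run from this directory, and the file paths are assumptions that should be adapted to your own run.
+
+```python
+from igf.igf import SecondaryLearner, generate_datasets
+from run_clm_igf import finetune
+from transformers import GPT2LMHeadModel
+
+model = GPT2LMHeadModel.from_pretrained("gpt2")
+
+# Reload the trained secondary learner, initialized from `model`'s embeddings
+secondary_learner = SecondaryLearner.from_pretrained("igf_model.pt", model)
+
+# Split the tokenized data into train/test sets, as done in run_clm_igf.py
+train_dataset, test_dataset = generate_datasets(
+    context_len=32, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
+)
+
+# Fine-tune GPT-2 with IGF; passing secondary_learner=None falls back to standard fine-tuning
+finetune(
+    model,
+    train_dataset,
+    test_dataset,
+    secondary_learner=secondary_learner,
+    finetuned_model_name="gpt2_finetuned.pt",
+)
+```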
+
+## Citation
+
+If you find the resource useful, please cite the following paper:
+
+```
+@inproceedings{antonello-etal-2021-selecting,
+ title = "Selecting Informative Contexts Improves Language Model Fine-tuning",
+ author = "Antonello, Richard and Beckage, Nicole and Turek, Javier and Huth, Alexander",
+ booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)",
+ month = aug,
+ year = "2021",
+ address = "Online",
+ publisher = "Association for Computational Linguistics",
+ url = "https://aclanthology.org/2021.acl-long.87",
+ doi = "10.18653/v1/2021.acl-long.87",
+ pages = "1072--1085",
+}
+```
diff --git a/examples/research_projects/information-gain-filtration/igf/__init__.py b/examples/research_projects/information-gain-filtration/igf/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/examples/research_projects/information-gain-filtration/igf/igf.py b/examples/research_projects/information-gain-filtration/igf/igf.py
new file mode 100644
index 000000000000..99bd8c2d06d7
--- /dev/null
+++ b/examples/research_projects/information-gain-filtration/igf/igf.py
@@ -0,0 +1,419 @@
+# Copyright 2022 - Intel Corp. All rights reserved.
+# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Beckage
+
+import copy
+import logging
+import random
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+import joblib
+from transformers import AdamW, GPT2LMHeadModel, get_linear_schedule_with_warmup
+
+
+logger = logging.getLogger(__name__)
+
+
+def set_seed(seed):
+ """
+ For reproducible training
+
+ Args:
+ seed: A seed for reproducible training
+
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def compute_perplexity(model, test_data, context_len):
+ """
+ Computes perplexity of the transformer model on data in test_data
+
+ Args:
+ model: Pre-trained GPT2 model
+ test_data: Data on which perplexity calculation is required
+ context_len: The maximum total input sequence length after tokenization. Sequences longer
+ than this will be truncated, sequences shorter will be padded
+
+ Returns:
+ Perplexity on input test data
+
+ """
+
+ model.eval()
+ device = next(model.parameters()).device
+ eval_batch_size = 1
+ context = torch.zeros((eval_batch_size, context_len), dtype=torch.long, device=device)
+ eval_dataloader = DataLoader(test_data, shuffle=False, batch_size=eval_batch_size)
+ eval_loss = torch.zeros(1, device=device)
+ nb_eval_examples = 0
+ for batch in eval_dataloader:
+ batch.to(device)
+ # pad
+ context.zero_()
+ for i in range(eval_batch_size):
+ context[i, :] = batch[i]
+ outputs = model(context, labels=context)
+ eval_loss += outputs[0].sum().item()
+ nb_eval_examples += batch.size(0)
+ eval_loss = eval_loss / nb_eval_examples
+ perplexity = torch.exp(eval_loss)
+ model.train()
+ return perplexity
+
+
+def load_gpt2(model_name="gpt2"):
+ """
+ load original gpt2 and save off for quicker loading
+
+ Args:
+ model_name: GPT-2
+
+ Returns:
+ GPT-2 model
+
+ """
+
+ model = GPT2LMHeadModel.from_pretrained(model_name, output_hidden_states=True)
+ torch.save(model.state_dict(), model_name + "local.pt")
+ return model
+
+
+def recopy_gpt2(orig_model, device, max_steps):
+ """
+ Reset the model to the original pretrained GPT-2 weights after each iteration
+
+ Args:
+ orig_model: Original pretrained GPT-2 model imported from Transformers library
+ device: CPU/GPU
+ max_steps: number of training steps
+
+ Returns:
+ Original PreTrained GPT-2 model,
+ lm_optimizer: Adam optimizer with Decoupled weight decay
+ lm_scheduler: linear scheduler with the appropriate schedule
+
+ """
+ model = copy.deepcopy(orig_model)
+ model.to(device)
+
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+ {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
+ ]
+ lm_optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
+ lm_scheduler = get_linear_schedule_with_warmup(lm_optimizer, 0, max_steps)
+ torch.cuda.empty_cache()
+ return model, lm_optimizer, lm_scheduler
+
+
+def intermittent_save(contexts, real_perps, past_perps, filename):
+
+ """
+ save the perplexity differences to filename
+
+ Args:
+ contexts: Example on which the perplexity is calculated
+ real_perps: Perplexity after back-propagating on the selected context
+ past_perps: Perplexity of model before training on the context
+ filename: File to store perplexity differences
+
+ Returns:
+ file with perplexity differences
+
+ """
+ # save the perplexity differences to filename
+ avg = np.array(real_perps).mean()
+ std = np.array(real_perps).std()
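+ # z-score the observed information-gain differences before saving them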
+ perp_diff = (real_perps - avg) / std
+ data_final = list(zip(contexts, perp_diff, past_perps))
+ joblib.dump(data_final, filename)
+
+
+def collect_objective_set(
+ model,
+ orig_perp,
+ context_len,
+ train_data,
+ objective_set,
+ max_steps,
+ device,
+ filename="dev.jbl",
+ recopy_model=recopy_gpt2,
+):
+
+ """
+ Collect individual IGF values from pre-trained transformer model
+ max_steps samples of training data to train secondary model
+
+ Args:
+ model: Pre-trained GPT2 model
+ orig_perp: Perplexity of original pretrained GPT-2 model
+ context_len: The maximum total input sequence length after tokenization. Sequences longer
+ than this will be truncated, sequences shorter will be padded
+ train_data: Data to train model
+ objective_set: Contexts used to create (X,IG(X)) pairs which is the training data for secondary learner
+ max_steps: To calculate training epochs of model
+ device: GPU/CPU
+ filename: To store intermediate perplexity differences
+ recopy_model: Reset the model to the original pretrained GPT-2 weights after each iteration
+
+ Returns:
+ File with intermediate perplexity differences, saved at regular intervals
+
+ """
+
+ # initialize variables to record relevant information
+ contexts = []
+ real_perps = []
+ past_perps = []
+
+ # Initialize the transformer model
+ orig_model = copy.deepcopy(model)
+ orig_model.to(device="cpu")
+ torch.cuda.empty_cache()
+
+ # Compute perplexity of initial transformer model for comparison
+ model.train()
+ model, lm_optimizer, lm_scheduler = recopy_model(orig_model, device, max_steps)
+
+ for step in tqdm(range(max_steps)):
+ context = torch.zeros((1, context_len), dtype=torch.long, device=device)
+ story = random.choice(train_data)
+ start = random.randint(0, len(story[0]) - context_len - 1)
+ context[0, :] = story[0][start : start + context_len]
+ lm_optimizer.zero_grad()
+ outputs = model(context, labels=context)
+ lm_loss = outputs[0]
+ past_perp = compute_perplexity(model, context, context_len)
+ model.train()
+ lm_loss.backward()
+ # Do LM backprop
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
+ lm_optimizer.step()
+ lm_scheduler.step() # Update learning rate schedule
+
+ # Compute perplexity after back-propagating on the selected context
+ real_perp = compute_perplexity(model, objective_set, context_len)
+
+ # Periodically save the stored (X, IG(X)) pairs
+ if step % 1000 == 0 and step > 1:
+ intermittent_save(contexts, real_perps, past_perps, filename)
+
+ # Reset the pretrained model to the original pretrained GPT-2 weights after each iteration
+ model, lm_optimizer, lm_scheduler = recopy_model(orig_model, device, max_steps)
+
+ past_perps.append(past_perp.item())
+ real_perps.append(orig_perp - real_perp.item())
+ contexts.append(np.array(context.cpu()))
+
+ intermittent_save(contexts, real_perps, past_perps, filename)
+
+
+def generate_datasets(
+ context_len, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
+):
+ """
+ Generate objective set and training set
+
+ Args:
+ context_len: The maximum total input sequence length after tokenization. Sequences longer
+ than this will be truncated, sequences shorter will be padded
+ file: Tokenized data split into training set and objective set
+ number: size of objective dataset
+ min_len: minimum length of a context in objective set
+ trim: If True truncate the context if it exceeds context length
+
+ Returns:
+ Generated objective set and training data
+
+
+ """
+ # Generate objective set and training set
+ # Designate the first number (100) articles that are long enough to be used
+ # as our objective set, rest (that are long enough) are training data for
+ # secondary learner
+
+ data = joblib.load(file)
+ print("data loaded")
+ objective_set = []
+ if trim:
+ for i, example in enumerate(data):
+ if len(example[0]) > min_len:
+ start = random.randint(0, len(example[0]) - context_len - 1)
+ objective_set.append(example[0, start : start + context_len])
+ if len(objective_set) >= number:
+ break
+ train_data = []
+ for j in range(i + 1, len(data)):
+ if len(data[j][0]) > min_len:
+ train_data.append(data[j])
+ else:
+ objective_set = data[0:number]
+ train_data = data[number:]
+
+ joblib.dump(objective_set, "objective_set.jbl")
+ print("objective set saved")
+ return train_data, objective_set
+
+
+def train_secondary_learner(
+ secondary_learner, train_dataset, max_epochs, batch_size, eval_freq=50, igf_model_path="secondary_learner.pt"
+):
+
+ """
+ Train the secondary learner (igf_model)
+
+ Args:
+ secondary_learner: secondary learner
+ train_dataset: data to train secondary learner
+ max_epochs: number of epochs to train secondary learner
+ batch_size: batch size of training data of secondary learner
+ eval_freq: secondary model evaluation can be triggered at eval_freq
+ igf_model_path: path to store trained secondary learner
+
+ Returns:
+ Trained secondary learner
+
+ """
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ # We will use the first 512 pairs from our dataset as a test set for
+ # our secondary learner and the rest to train
+ test_dataset = train_dataset[:512]
+ train_dataset = train_dataset[512:]
+ train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
+ test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
+
+ # secondary learner model set up
+ loss = nn.MSELoss()
+ test_loss = nn.MSELoss(reduction="sum")
+ secondary_learner.to(device)
+ q_optimizer = torch.optim.Adam(secondary_learner.parameters(), lr=0.00001)
+ secondary_learner.train()
+
+ # TODO in original code this is written as number of actual batches seen
+ # not number of items seen but other places it is number of items instead.
+ # improve consistency! changed this to epochs for clarity
+ best_test_loss = float("inf")
+ # Iterate through batches until we've used max_steps batches
+ for epoch in range(int(max_epochs)):
+ tr_q_loss = 0.0
+ secondary_learner.train()
+ for step, batch in enumerate(train_dataloader):
+ context = batch[0].to(device)
+ real_q = batch[1].to(device)
+ predicted_q = secondary_learner(context)
+ q_optimizer.zero_grad()
+ q_loss = loss(predicted_q, real_q.float())
+ q_loss.backward()
+ q_optimizer.step()
+ tr_q_loss += q_loss.item()
+
+ # model trains fairly quickly so we won't wait for a full epoch
+ # eval is triggered at eval_freq and end of epochs
+ if (step % eval_freq == 0 and step > 0) or ((step + 1) == len(train_dataloader)):
+ tr_loss = tr_q_loss / (step + 1)
+
+ secondary_learner.eval()
+ q_loss2 = 0.0
+ sum_q2 = 0.0
+ predicted = []
+ actual = []
+ # Compute performance of the secondary learner after this batch
+ for step2, batch2 in enumerate(test_dataloader):
+ features2 = batch2[0].to(device)
+ real_q2 = batch2[1].to(device)
+ predicted_q2 = secondary_learner(features2)
+ q_loss2 += test_loss(predicted_q2, real_q2).item()
+ sum_q2 += torch.sum(predicted_q2).item()
+ for ei, i in enumerate(predicted_q2.cpu().detach().numpy()):
+ predicted.append(i.item())
+ for ei, i in enumerate(real_q2.cpu().detach().numpy()):
+ actual.append(i.item())
+
+ q_loss2 /= len(test_dataset)
+ print(
+ "Epoch: ",
+ epoch,
+ "step: ",
+ step,
+ "Avg. q:",
+ sum_q2 / len(test_dataset),
+ "Train Loss: ",
+ tr_loss,
+ "Test Loss: ",
+ q_loss2,
+ )
+ if q_loss2 < best_test_loss:
+ joblib.dump((predicted, actual), "pred_vs_actual.jbl")
+ torch.save(secondary_learner.state_dict(), igf_model_path)
+ best_test_loss = q_loss2
+
+ secondary_learner.train()
+ return secondary_learner
+
+
+class SecondaryLearner(nn.Module):
+ """
+ Our secondary learner
+ """
+
+ def __init__(self, model):
+ """
+ We use a simple convolutional network as our secondary learner
+
+ Args:
+ model: Pre-trained GPT2 model
+ """
+ # embeddings are from the pretrained model
+ super(SecondaryLearner, self).__init__()
+ self.embeddings = model.transformer.wte
+ self.embeddings.weight = copy.deepcopy(model.transformer.wte.weight)
+ self.conv = nn.Conv1d(self.embeddings.weight.size(1), 256, 3, padding=1)
+ self.fc = nn.Sequential(nn.Linear(256, 32), nn.Dropout(p=0.1), nn.Linear(32, 32), nn.Linear(32, 1))
+
+ def forward(self, context):
+ """
+ Forward pass through the secondary learner
+
+ Args:
+ context: Context input to the secondary learner
+
+ Returns:
+ tensor after squeeze operation
+
+ """
+ pooled = torch.max(self.conv(self.embeddings(context).squeeze(1).transpose(1, 2)), 2)[0]
+ qs = self.fc(pooled)
+ return qs.squeeze(1)
+
+ @classmethod
+ def from_pretrained(cls, state_path, model):
+ """
+ Load the secondary learner
+
+ Args:
+ state_path: Path to save secondary learner
+ model: Pretrained GPT-2
+
+ Returns:
+ secondary learner
+ """
+
+ secondary_learner = cls(model) # this calls __init__
+ state_dict = torch.load(state_path)
+ secondary_learner.load_state_dict(state_dict)
+ secondary_learner.embeddings = model.transformer.wte
+ secondary_learner.embeddings.weight = copy.deepcopy(model.transformer.wte.weight)
+ return secondary_learner
diff --git a/examples/research_projects/information-gain-filtration/requirements.txt b/examples/research_projects/information-gain-filtration/requirements.txt
new file mode 100644
index 000000000000..2aa3227637c8
--- /dev/null
+++ b/examples/research_projects/information-gain-filtration/requirements.txt
@@ -0,0 +1,6 @@
+matplotlib
+numpy>=1.17.2
+joblib>=0.13.2
+scipy
+torch>=1.10.1
+transformers>=3.5
\ No newline at end of file
diff --git a/examples/research_projects/information-gain-filtration/result_igf.png b/examples/research_projects/information-gain-filtration/result_igf.png
new file mode 100644
index 000000000000..10bb0b7d6816
Binary files /dev/null and b/examples/research_projects/information-gain-filtration/result_igf.png differ
diff --git a/examples/research_projects/information-gain-filtration/run_clm_igf.py b/examples/research_projects/information-gain-filtration/run_clm_igf.py
new file mode 100644
index 000000000000..eae10060b22f
--- /dev/null
+++ b/examples/research_projects/information-gain-filtration/run_clm_igf.py
@@ -0,0 +1,446 @@
+# Copyright 2022 - Intel Corp. All rights reserved.
+# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Beckage
+
+"""
+Implementation of a new method for fine-tuning transformer models that we call
+Information Gain Filtration (IGF) on the WikiText dataset, comparing the results
+with the standard fine-tuning method.
+
+Steps followed in the code:
+
+1) Generate an objective dataset of pairs (X, IG(X)), where IG(X) is the informativeness of context 'X'.
+Our IG (information gain) model learns to predict the "informativeness" of a particular
+context. Informativeness is the change in a metric between the model's performance on an
+objective set before and after seeing that context. For causal language modeling, the
+metric is perplexity.
+
+2) A secondary learner is trained to infer a function approximation for IG using the dataset
+created in (1).
+
+3) The learner created in (2) is used to inform the fine-tuning process and filter out low informative samples.
+
+Finally, a plot is generated to compare the performance of IGF to standard fine-tuning without any filtering.
+
+"""
+
+# Prerequisite libraries:
+
+import argparse
+import random
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, RandomSampler
+
+import joblib
+from igf.igf import (
+ SecondaryLearner,
+ collect_objective_set,
+ compute_perplexity,
+ generate_datasets,
+ load_gpt2,
+ recopy_gpt2,
+ set_seed,
+ train_secondary_learner,
+)
+from transformers import GPT2LMHeadModel
+
+
+def generate_n_pairs(
+ context_len=32,
+ max_steps=10,
+ size_objective_set=100,
+ min_len=1026,
+ trim=True,
+ data_file="data/tokenized_stories_train_wikitext103.jbl",
+ igf_data_file="igf_context_pairs.jbl",
+):
+
+ """
+ Collecting *n* pairs for training the secondary learner
+ Args:
+ context_len: The maximum total input sequence length after tokenization. Sequences longer
+ than this will be truncated, sequences shorter will be padded
+ max_steps: To calculate training epochs of secondary learner
+ size_objective_set: size of objective data set used to create (X,IG(X)) pairs which is the training data for secondary learner
+ min_len: The minimum length of the article to be used as objective set
+ trim: If True truncate the context if it exceeds context length
+ data_file: Tokenized data set split for training and evaluation of model
+ igf_data_file: file to store (X, IG(X)) paired data set to train secondary learner
+
+ Returns:
+ Data stored in igf_data_file
+
+ """
+ # generates the same data every time
+ set_seed(3)
+ # generate train_data and objective_set
+ train_data, objective_set = generate_datasets(
+ context_len, data_file, number=size_objective_set, min_len=min_len, trim=trim
+ )
+ # keeps model same across runs
+ set_seed(4)
+ # model, lm_optimizer, lm_scheduler = recopy_gpt2(model, device, max_steps) # store original model weights
+ # can we train on GPU?
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # load pretrained model
+ model = load_gpt2("gpt2").to(device)
+ print("computing perplexity on objective set")
+ orig_perp = compute_perplexity(model, objective_set, context_len).item()
+ print("perplexity on objective set:", orig_perp)
+
+ # collect igf pairs and save to file demo.jbl
+ collect_objective_set(model, orig_perp, context_len, train_data, objective_set, max_steps, device, igf_data_file)
+
+ # clean up, delete model and data we don't need anymore
+ del model, train_data, objective_set
+ torch.cuda.empty_cache()
+
+
+def training_secondary_learner(
+ secondary_learner_train_data,
+ secondary_learner_max_epochs=15,
+ secondary_learner_batch_size=128,
+ eval_freq=100,
+ igf_model_path="igf_model.pt",
+):
+ """
+ Train the secondary learner
+
+ Args:
+ secondary_learner_train_data: Dataset of (X, IG(X)) pairs to train the secondary learner, where IG(X) is a measure of the informativeness of context X
+ secondary_learner_max_epochs: Number of epochs to train secondary learner
+ secondary_learner_batch_size: Batch size to train secondary learner
+ eval_freq (object): secondary model evaluation can be triggered at eval_freq
+ igf_model_path: path to store trained secondary learner
+
+ Returns:
+ Trained secondary learner
+ """
+
+ set_seed(42)
+
+ # Load pre-trained model
+ model = GPT2LMHeadModel.from_pretrained("gpt2")
+
+ # Initialize secondary learner to use embedding weights of model
+ secondary_learner = SecondaryLearner(model)
+
+ # Train secondary learner
+ secondary_learner = train_secondary_learner(
+ secondary_learner,
+ secondary_learner_train_data,
+ max_epochs=secondary_learner_max_epochs,
+ batch_size=secondary_learner_batch_size,
+ eval_freq=eval_freq,
+ igf_model_path=igf_model_path,
+ )
+
+ del model, secondary_learner_train_data
+ torch.cuda.empty_cache()
+
+ return secondary_learner
+
+
+def finetune(
+ model,
+ train_dataset,
+ test_dataset,
+ context_len=32,
+ max_steps=1000,
+ batch_size=16,
+ threshold=1.0,
+ recopy_model=recopy_gpt2,
+ secondary_learner=None,
+ eval_interval=10,
+ finetuned_model_name="gpt2_finetuned.pt",
+):
+ """
+ fine-tune with IGF if secondary_learner is not None, else standard fine-tuning
+
+ Args:
+ model: pre-trained GPT-2 model
+ train_dataset: Data set to train GPT-2 model
+ test_dataset: Evaluate GPT-2 model
+ context_len: The maximum total input sequence length after tokenization. Sequences longer
+ than this will be truncated, sequences shorter will be padded
+ max_steps: To calculate training epochs
+ batch_size: Batch size to train GPT-2 model
+ threshold: The threshold value used by secondary learner to filter the train_data and allow only
+ informative data as input to the model
+ recopy_model: Reset the model to the original pretrained GPT-2 weights after each iteration
+ secondary_learner: Selection of IGF as fine-tuning method if not None
+ eval_interval: number of batches after which decay the selectivity of our secondary learner filter from
+ 1 standard deviation above average to 1 below average
+ finetuned_model_name: name of the final fine-tuned GPT-2 model
+
+ Returns:
+ Fine-tuned GPT-2 model
+
+ """
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ train_sampler = RandomSampler(train_dataset)
+ train_dataloader = DataLoader(train_dataset, sampler=train_sampler)
+
+ num_train_epochs = max_steps // (len(train_dataset)) + 1
+ global_step = 0
+ context = torch.zeros((1, context_len), dtype=torch.long, device=device)
+ model, lm_optimizer, lm_scheduler = recopy_model(model, device, max_steps)
+
+ model.train()
+ if secondary_learner is not None:
+ secondary_learner.to(device)
+ secondary_learner.eval()
+ contexts = []
+ examples = 0
+
+ observed_qs = []
+ test_perps = []
+
+ # Compute the performance of the transformer model at the beginning
+ real_perp = compute_perplexity(model, test_dataset, context_len)
+ test_perps.append(real_perp)
+ print("Test perplexity, step", global_step, ":", real_perp)
+ for epoch in range(int(num_train_epochs)):
+ for step, example in enumerate(train_dataloader):
+ torch.cuda.empty_cache()
+ start = random.randint(0, example.size(2) - context_len - 1)
+ context[0, :] = example[0, 0, start : start + context_len]
+ lm_optimizer.zero_grad()
+ outputs = model(context, labels=context)
+ do_backprop = True
+
+ if secondary_learner is not None:
+ predicted_q = secondary_learner.forward(
+ torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)
+ )[0].item()
+ observed_qs.append(float(predicted_q))
+
+ # Here we implement the simple non-constant threshold for the predicted IG(X) value
+ # We will decay the selectivity of our secondary learner filter from
+ # 1 standard deviation above average to 1 below average after 10 batches.
+
+ if global_step == 10:
+ threshold = -1
+ if predicted_q < threshold:
+ do_backprop = False
+
+ # If we passed the filter, add the context to the batch!
+ if do_backprop:
+ contexts.append(np.array(context.cpu()))
+ lm_loss = outputs[0]
+ lm_loss.backward()
+ examples += 1
+
+ del outputs
+
+ # Once the batch is filled with enough contexts, backprop on the batch.
+ if examples == batch_size:
+ torch.cuda.empty_cache()
+ examples = 0
+ # Do LM backprop
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
+ lm_optimizer.step()
+ lm_scheduler.step() # Update learning rate schedule
+ global_step += 1
+ # Compute the performance of the transformer model at this batch
+ if global_step % eval_interval == 0:
+ real_perp = compute_perplexity(model, test_dataset, context_len)
+ test_perps.append(real_perp)
+
+ print("Test perplexity, step", global_step, ":", real_perp)
+ # Break out of the loop after 60 batches
+ if max_steps > 0 and global_step > 60:
+ break
+ if max_steps > 0 and global_step > 60:
+ break
+
+ # save finetuned transformer model
+ torch.save(model.state_dict(), finetuned_model_name)
+ torch.cuda.empty_cache()
+ # Do some cleaning up so we can reinitialize for the next run of this function
+ del lm_optimizer
+ del lm_scheduler
+ return model
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Fine-tune a transformer model with IGF on a language modeling task")
+
+ # Required parameters
+ parser.add_argument(
+ "--data_dir",
+ default=None,
+ type=str,
+ required=True,
+ help="The input data dir. Should contain data files for WikiText.",
+ )
+ parser.add_argument(
+ "--model_name_or_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models",
+ )
+ parser.add_argument(
+ "--data_file",
+ type=str,
+ default=None,
+ help=(
+ "A jbl file containing tokenized data which can be split as objective dataset, "
+ "train_dataset and test_dataset."
+ ),
+ )
+
+ parser.add_argument(
+ "--igf_data_file",
+ type=str,
+ default=None,
+ help="A jbl file containing the context and information gain pairs to train secondary learner.",
+ )
+
+ parser.add_argument(
+ "--output_dir",
+ default=None,
+ type=str,
+ required=True,
+ help="The output directory where the final fine-tuned model is stored.",
+ )
+
+ parser.add_argument(
+ "--tokenizer_name",
+ default=None,
+ type=str,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+
+ parser.add_argument(
+ "--context_len",
+ default=32,
+ type=int,
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
+ )
+
+ parser.add_argument(
+ "--size_objective_set",
+ default=100,
+ type=int,
+ help="number of articles that are long enough to be used as our objective set",
+ )
+ parser.add_argument(
+ "--eval_freq", default=100, type=int, help="secondary model evaluation is triggered at eval_freq"
+ )
+
+ parser.add_argument("--max_steps", default=1000, type=int, help="To calculate training epochs")
+
+ parser.add_argument(
+ "--secondary_learner_batch_size",
+ default=128,
+ type=int,
+ help="batch size of training data for secondary learner",
+ )
+
+ parser.add_argument(
+ "--batch_size", default=16, type=int, help="batch size of training data of language model(gpt2) "
+ )
+
+ parser.add_argument(
+ "--eval_interval",
+ default=10,
+ type=int,
+ help=(
+ "decay the selectivity of our secondary learner filter from"
+ "1 standard deviation above average to 1 below average after 10 batches"
+ ),
+ )
+
+ parser.add_argument(
+ "--number", default=100, type=int, help="The number of examples split to be used as objective_set/test_data"
+ )
+
+ parser.add_argument(
+ "--min_len", default=1026, type=int, help="The minimum length of the article to be used as objective set"
+ )
+
+ parser.add_argument(
+ "--secondary_learner_max_epochs", default=15, type=int, help="number of epochs to train secondary learner"
+ )
+
+ parser.add_argument("--trim", default=True, type=bool, help="truncate the example if it exceeds context length")
+
+ parser.add_argument(
+ "--threshold",
+ default=1.0,
+ type=float,
+ help=(
+ "The threshold value used by secondary learner to filter the train_data and allow only"
+ " informative data as input to the model"
+ ),
+ )
+
+ parser.add_argument("--finetuned_model_name", default="gpt2_finetuned.pt", type=str, help="finetuned_model_name")
+
+ parser.add_argument(
+ "--recopy_model",
+ default=recopy_gpt2,
+ type=str,
+ help="Reset the model to the original pretrained GPT-2 weights after each iteration",
+ )
+
+ # function calls
+ # Collecting *n* pairs of context and information gain(X, IG(X)) for training the secondary learner
+ generate_n_pairs(
+ context_len=32,
+ max_steps=10,
+ size_objective_set=100,
+ min_len=1026,
+ trim=True,
+ data_file="data/tokenized_stories_train_wikitext103.jbl",
+ igf_data_file="igf_context_pairs.jbl",
+ )
+
+ # Load train data for secondary learner
+ secondary_learner_train_data = joblib.load("data/IGF_values.jbl")
+
+ # Train secondary learner
+ secondary_learner = training_secondary_learner(
+ secondary_learner_train_data,
+ secondary_learner_max_epochs=15,
+ secondary_learner_batch_size=128,
+ eval_freq=100,
+ igf_model_path="igf_model.pt",
+ )
+
+ # load pretrained gpt2 model
+ model = GPT2LMHeadModel.from_pretrained("gpt2")
+ set_seed(42)
+
+ # Generate train and test data to train and evaluate gpt2 model
+ train_dataset, test_dataset = generate_datasets(
+ context_len=32, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
+ )
+
+ # fine-tuning of the gpt2 model using igf (Information Gain Filtration)
+ finetune(
+ model,
+ train_dataset,
+ test_dataset,
+ context_len=32,
+ max_steps=1000,
+ batch_size=16,
+ threshold=1.0,
+ recopy_model=recopy_gpt2,
+ secondary_learner=secondary_learner,
+ eval_interval=10,
+ finetuned_model_name="gpt2_finetuned.pt",
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/research_projects/jax-projects/README.md b/examples/research_projects/jax-projects/README.md
index 56316ef940a1..0b3f0dc5d24f 100644
--- a/examples/research_projects/jax-projects/README.md
+++ b/examples/research_projects/jax-projects/README.md
@@ -49,7 +49,7 @@ At the end of the community week, each team should submit a demo of their projec
- **23.06.** Official announcement of the community week. Make sure to sign-up in [this google form](https://forms.gle/tVGPhjKXyEsSgUcs8).
- **23.06. - 30.06.** Participants will be added to an internal Slack channel. Project ideas can be proposed here and groups of 3-5 are formed. Read this document for more information.
-- **30.06.** Release of all relevant training scripts in JAX/Flax as well as other documents on how to set up a TPU, how to use the training scripts, how to submit a demo, tips & tricks for JAX/Flax, tips & tricks for efficient use of the hub.
+- **30.06.** Release of all relevant training scripts in JAX/Flax as well as other documents on how to set up a TPU, how to use the training scripts, how to submit a demo, tips & tricks for JAX/Flax, tips & tricks for efficient use of the hub.
- **30.06. - 2.07.** Talks about JAX/Flax, TPU, Transformers, Computer Vision & NLP will be held.
- **7.07.** Start of the community week! Access to TPUv3-8 will be given to each team.
- **7.07. - 14.07.** The Hugging Face & JAX/Flax & Cloud team will be available for any questions, problems the teams might run into.
diff --git a/examples/research_projects/jax-projects/big_bird/evaluate.py b/examples/research_projects/jax-projects/big_bird/evaluate.py
index de01e8fc81a3..e3309f494e34 100644
--- a/examples/research_projects/jax-projects/big_bird/evaluate.py
+++ b/examples/research_projects/jax-projects/big_bird/evaluate.py
@@ -106,7 +106,7 @@ def forward(*args, **kwargs):
return start_logits, end_logits, jnp.argmax(pooled_logits, axis=-1)
def evaluate(example):
- # encode question and context so that they are seperated by a tokenizer.sep_token and cut at max_length
+ # encode question and context so that they are separated by a tokenizer.sep_token and cut at max_length
inputs = tokenizer(
example["question"],
example["context"],
diff --git a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py
index 0bb4a7b9c514..fadcec09cbf0 100755
--- a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py
+++ b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py
@@ -75,8 +75,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -99,7 +100,10 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
@@ -141,8 +145,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated. Default to the max input length of the model."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated. Default to the max input length of the model."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -155,8 +161,10 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
line_by_line: bool = field(
@@ -256,7 +264,7 @@ def mask_tokens(
return inputs, labels
-def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
num_samples = len(samples_idx)
samples_to_remove = num_samples % batch_size
@@ -280,8 +288,10 @@ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
tokenized_samples = next(train_iterator)
i += len(tokenized_samples["input_ids"])
- # concatenate tokenized samples to list
- samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
+ # concatenate tokenized samples to list (excluding "id" and "text")
+ samples = {
+ k: samples[k] + tokenized_samples[k] for k in ["input_ids", "attention_mask", "special_tokens_mask"]
+ }
# Concatenated tokens are split to lists of length `max_seq_length`.
# Note that remainedr of % max_seq_length are thrown away.
@@ -399,10 +409,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
def tokenize_function(examples):
return tokenizer(examples[data_args.text_column_name], return_special_tokens_mask=True)
- tokenized_datasets = dataset.map(
- tokenize_function,
- batched=True,
- )
+ tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=list(dataset.features.keys()))
shuffle_seed = training_args.seed
tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
@@ -575,7 +582,8 @@ def eval_step(params, batch):
if step % training_args.logging_steps == 0 and step > 0:
steps.write(
- f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+ f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
+ f" {train_metric['learning_rate'].mean()})"
)
train_time += time.time() - train_start
if has_tensorboard and jax.process_index() == 0:
@@ -584,7 +592,8 @@ def eval_step(params, batch):
# ======================== Evaluating ==============================
if step % training_args.eval_steps == 0 and step > 0:
- eval_samples_idx = jnp.arange(data_args.num_eval_samples)
+ # Avoid using jax.numpy here in case of TPU training
+ eval_samples_idx = np.arange(data_args.num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
@@ -604,7 +613,10 @@ def eval_step(params, batch):
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
# Update progress bar
- steps.desc = f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
+ steps.desc = (
+ f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc:"
+ f" {eval_metrics['accuracy']})"
+ )
if has_tensorboard and jax.process_index() == 0:
write_eval_metric(summary_writer, eval_metrics, step)
diff --git a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py
index 0572a4e019a8..6ee974666a29 100644
--- a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py
+++ b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py
@@ -77,14 +77,18 @@ class ModelArguments:
text_model_name_or_path: str = field(
metadata={
- "help": "The text model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The text model checkpoint for weights initialization."
+ "Don't set if you want to train a model from scratch."
+ )
},
)
vision_model_name_or_path: str = field(
metadata={
- "help": "The vision model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The vision model checkpoint for weights initialization."
+ "Don't set if you want to train a model from scratch."
+ )
},
)
from_pt: bool = field(
@@ -107,7 +111,10 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
@@ -129,22 +136,28 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=72,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
overwrite_cache: bool = field(
@@ -519,7 +532,8 @@ def eval_step(params, batch):
train_step_progress_bar.close()
epochs.write(
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
# ======================== Evaluating ==============================
diff --git a/examples/research_projects/jax-projects/model_parallel/README.md b/examples/research_projects/jax-projects/model_parallel/README.md
index 6b6998b56ad2..b63b93862db0 100644
--- a/examples/research_projects/jax-projects/model_parallel/README.md
+++ b/examples/research_projects/jax-projects/model_parallel/README.md
@@ -22,7 +22,7 @@ the JAX/Flax backend and the [`pjit`](https://jax.readthedocs.io/en/latest/jax.e
> Note: The example is experimental and might have bugs. Also currently it only supports single V3-8.
The `partition.py` file defines the `PyTree` of `ParitionSpec` for the GPTNeo model which describes how the model will be sharded.
-The actual sharding is auto-matically handled by `pjit`. The weights are sharded accross all local devices.
+The actual sharding is automatically handled by `pjit`. The weights are sharded across all local devices.
To adapt the script for other models, we need to also change the `ParitionSpec` accordingly.
TODO: Add more explantion.
diff --git a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py
index 3371dc3bd4df..518ef9f7b22f 100644
--- a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py
+++ b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py
@@ -69,8 +69,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -93,7 +94,10 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
@@ -118,15 +122,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
overwrite_cache: bool = field(
@@ -141,9 +149,11 @@ class DataTrainingArguments:
block_size: Optional[int] = field(
default=None,
metadata={
- "help": "Optional input sequence length after tokenization. "
- "The training dataset will be truncated in block of this size for training. "
- "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ "help": (
+ "Optional input sequence length after tokenization. "
+ "The training dataset will be truncated in block of this size for training. "
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ )
},
)
overwrite_cache: bool = field(
@@ -334,7 +344,8 @@ def tokenize_function(examples):
# clm input could be much much longer than block_size
if "Token indices sequence length is longer than the" in cl.out:
tok_logger.warning(
- "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
+ " before being passed to the model."
)
return output
@@ -606,7 +617,8 @@ def eval_step(input_ids, labels, params):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
train_metrics = []
@@ -632,7 +644,8 @@ def eval_step(input_ids, labels, params):
eval_metrics["perplexity"] = float("inf")
logger.info(
- f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
+ f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity:"
+ f" {eval_metrics['perplexity']}"
)
if cur_step % training_args.save_steps == 0 and cur_step > 0:
diff --git a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py
index e2bcd7861bec..457c58d44fde 100755
--- a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py
+++ b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py
@@ -64,7 +64,10 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
@@ -94,7 +97,9 @@ class DataTrainingArguments:
validation_split_name: Optional[str] = field(
default="validation",
metadata={
- "help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
+ "help": (
+ "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
+ )
},
)
speech_file_column: Optional[str] = field(
@@ -120,7 +125,10 @@ class DataTrainingArguments:
pad_to_multiple_of: Optional[int] = field(
default=1024,
metadata={
- "help": "If set will pad the sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU"
+ "help": (
+ "If set will pad the sequence to a multiple of the provided value. This is important to avoid"
+ " triggering recompilations on TPU"
+ )
},
)
@@ -229,7 +237,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
summary_writer.scalar(f"eval_{metric_name}", value, step)
-def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
num_samples = len(samples_idx)
samples_to_remove = num_samples % batch_size
@@ -357,7 +365,8 @@ def normalize(batch):
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
raise ValueError(
- "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
+ "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and"
+ " ``config.feat_extract_norm='layer'"
)
model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
@@ -532,7 +541,8 @@ def eval_step(params, batch):
# Generate an epoch by shuffling sampling indices from the train dataset
num_train_samples = len(vectorized_datasets["train"])
- train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
+ # Avoid using jax.numpy here in case of TPU training
+ train_samples_idx = np.random.permutation(np.arange(num_train_samples))
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
# Gather the indexes for creating the batch and do a training step
@@ -557,14 +567,16 @@ def eval_step(params, batch):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
+ f" {train_metric['learning_rate'].mean()})"
)
train_metrics = []
# ======================== Evaluating ==============================
num_eval_samples = len(vectorized_datasets["validation"])
- eval_samples_idx = jnp.arange(num_eval_samples)
+ # Avoid using jax.numpy here in case of TPU training
+ eval_samples_idx = np.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = []
@@ -583,7 +595,8 @@ def eval_step(params, batch):
# Update progress bar
epochs.write(
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})"
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity:"
+ f" {eval_metrics['codevector_perplexity']})"
)
# Save metrics
diff --git a/examples/research_projects/layoutlmv3/README.md b/examples/research_projects/layoutlmv3/README.md
new file mode 100644
index 000000000000..17bf4bb67cd9
--- /dev/null
+++ b/examples/research_projects/layoutlmv3/README.md
@@ -0,0 +1,69 @@
+
+
+# Token classification with LayoutLMv3 (PyTorch version)
+
+This directory contains a script, `run_funsd_cord.py`, that can be used to fine-tune (or evaluate) LayoutLMv3 on form understanding datasets, such as [FUNSD](https://guillaumejaume.github.io/FUNSD/) and [CORD](https://github.com/clovaai/cord).
+
+The script `run_funsd_cord.py` leverages the 🤗 Datasets library and the Trainer API. You can easily customize it to your needs.
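+
+To get a quick feel for the data, the dataset the script loads for FUNSD can be inspected directly with 🤗 Datasets. Below is a minimal sketch (it assumes `nielsr/funsd-layoutlmv3`, the dataset id the script passes to `load_dataset` when `--dataset_name funsd` is given):
+
+```python
+from datasets import load_dataset
+
+# Same dataset id that run_funsd_cord.py uses for --dataset_name funsd
+dataset = load_dataset("nielsr/funsd-layoutlmv3")
+
+# Inspect the splits and the word/box/label columns (e.g. tokens or words, bboxes, ner_tags) plus the image column
+print(dataset)
+print(dataset["train"].features)
+```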
+
+## Fine-tuning on FUNSD
+
+Fine-tuning LayoutLMv3 for token classification on [FUNSD](https://guillaumejaume.github.io/FUNSD/) can be done as follows:
+
+```bash
+python run_funsd_cord.py \
+ --model_name_or_path microsoft/layoutlmv3-base \
+ --dataset_name funsd \
+ --output_dir layoutlmv3-test \
+ --do_train \
+ --do_eval \
+ --max_steps 1000 \
+ --evaluation_strategy steps \
+ --eval_steps 100 \
+ --learning_rate 1e-5 \
+ --load_best_model_at_end \
+ --metric_for_best_model "eval_f1" \
+ --push_to_hub \
+ --push_to_hub_model_id layoutlmv3-finetuned-funsd
+```
+
+The resulting model can be found here: https://huggingface.co/nielsr/layoutlmv3-finetuned-funsd. By specifying the `push_to_hub` flag, the model gets uploaded to the hub at regular intervals during training, together with a model card, which includes metrics such as precision, recall and F1. Note that you can easily update the model card, as it's just a README file of the respective repo on the hub.
+
+There's also the "Training metrics" [tab](https://huggingface.co/nielsr/layoutlmv3-finetuned-funsd/tensorboard), which shows Tensorboard logs over the course of training. Pretty neat, huh?
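+
+Once fine-tuned (or using the checkpoint linked above), the model can also be queried outside the Trainer. Below is a minimal, hypothetical inference sketch: the words and boxes are made up, and the boxes are assumed to already be in the 0-1000 normalized format that LayoutLMv3 expects:
+
+```python
+from PIL import Image
+from transformers import AutoModelForTokenClassification, AutoProcessor
+
+# apply_ocr=False because we supply the words and boxes ourselves
+processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
+model = AutoModelForTokenClassification.from_pretrained("nielsr/layoutlmv3-finetuned-funsd")
+
+image = Image.new("RGB", (224, 224), "white")  # stand-in for a scanned form
+words = ["Invoice", "Date:", "11/02/2019"]  # made-up example words
+boxes = [[50, 50, 200, 80], [50, 100, 130, 130], [140, 100, 260, 130]]  # normalized 0-1000 boxes
+
+encoding = processor(image, words, boxes=boxes, return_tensors="pt")
+logits = model(**encoding).logits
+predicted_ids = logits.argmax(-1).squeeze().tolist()
+print([model.config.id2label[i] for i in predicted_ids])  # one label per (sub)token, incl. special tokens
+```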
+
+## Fine-tuning on CORD
+
+Fine-tuning LayoutLMv3 for token classification on [CORD](https://github.com/clovaai/cord) can be done as follows:
+
+```bash
+python run_funsd_cord.py \
+ --model_name_or_path microsoft/layoutlmv3-base \
+ --dataset_name cord \
+ --output_dir layoutlmv3-test \
+ --do_train \
+ --do_eval \
+ --max_steps 1000 \
+ --evaluation_strategy steps \
+ --eval_steps 100 \
+ --learning_rate 5e-5 \
+ --load_best_model_at_end \
+ --metric_for_best_model "eval_f1" \
+ --push_to_hub \
+ --push_to_hub_model_id layoutlmv3-finetuned-cord
+```
+
+The resulting model can be found here: https://huggingface.co/nielsr/layoutlmv3-finetuned-cord. Note that a model card gets generated automatically if you specify the `push_to_hub` flag.
\ No newline at end of file
diff --git a/examples/research_projects/layoutlmv3/requirements.txt b/examples/research_projects/layoutlmv3/requirements.txt
new file mode 100644
index 000000000000..504a8cc9870f
--- /dev/null
+++ b/examples/research_projects/layoutlmv3/requirements.txt
@@ -0,0 +1,2 @@
+datasets
+seqeval
\ No newline at end of file
diff --git a/examples/research_projects/layoutlmv3/run_funsd_cord.py b/examples/research_projects/layoutlmv3/run_funsd_cord.py
new file mode 100644
index 000000000000..866f9a9c1b11
--- /dev/null
+++ b/examples/research_projects/layoutlmv3/run_funsd_cord.py
@@ -0,0 +1,533 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Fine-tuning LayoutLMv3 for token classification on FUNSD or CORD.
+"""
+# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as
+# comments.
+
+import logging
+import os
+import sys
+from dataclasses import dataclass, field
+from typing import Optional
+
+import datasets
+import numpy as np
+from datasets import ClassLabel, load_dataset, load_metric
+
+import transformers
+from transformers import (
+ AutoConfig,
+ AutoModelForTokenClassification,
+ AutoProcessor,
+ HfArgumentParser,
+ Trainer,
+ TrainingArguments,
+ set_seed,
+)
+from transformers.data.data_collator import default_data_collator
+from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils import check_min_version
+from transformers.utils.versions import require_version
+
+
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.19.0.dev0")
+
+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ModelArguments:
+ """
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+ """
+
+ model_name_or_path: str = field(
+ default="microsoft/layoutlmv3-base",
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
+ )
+ config_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+ )
+ processor_name: Optional[str] = field(
+ default=None, metadata={"help": "Name or path to the processor files if not the same as model_name"}
+ )
+ cache_dir: Optional[str] = field(
+ default=None,
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+ )
+ model_revision: str = field(
+ default="main",
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+ )
+ use_auth_token: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
+ },
+ )
+
+
+@dataclass
+class DataTrainingArguments:
+ """
+ Arguments pertaining to what data we are going to input our model for training and eval.
+ """
+
+ task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."})
+ dataset_name: Optional[str] = field(
+ default="nielsr/funsd-layoutlmv3",
+ metadata={"help": "The name of the dataset to use (via the datasets library)."},
+ )
+ dataset_config_name: Optional[str] = field(
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+ )
+ train_file: Optional[str] = field(
+ default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
+ )
+ validation_file: Optional[str] = field(
+ default=None,
+ metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
+ )
+ test_file: Optional[str] = field(
+ default=None,
+ metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
+ )
+ text_column_name: Optional[str] = field(
+ default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."}
+ )
+ label_column_name: Optional[str] = field(
+ default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."}
+ )
+ overwrite_cache: bool = field(
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+ )
+ preprocessing_num_workers: Optional[int] = field(
+ default=None,
+ metadata={"help": "The number of processes to use for the preprocessing."},
+ )
+ max_seq_length: int = field(
+ default=512,
+ metadata={
+ "help": (
+ "The maximum total input sequence length after tokenization. If set, sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
+ },
+ )
+ max_train_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
+ },
+ )
+ max_eval_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
+ },
+ )
+ max_predict_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
+ },
+ )
+ label_all_tokens: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to put the label for one word on all tokens of generated by that word or just on the "
+ "one (in which case the other tokens will have a padding index)."
+ )
+ },
+ )
+ return_entity_level_metrics: bool = field(
+ default=False,
+ metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."},
+ )
+
+ def __post_init__(self):
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
+ raise ValueError("Need either a dataset name or a training/validation file.")
+ else:
+ if self.train_file is not None:
+ extension = self.train_file.split(".")[-1]
+ assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+ if self.validation_file is not None:
+ extension = self.validation_file.split(".")[-1]
+ assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+ self.task_name = self.task_name.lower()
+
+
+def main():
+ # See all possible arguments in src/transformers/training_args.py
+ # or by passing the --help flag to this script.
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+ # If we pass only one argument to the script and it's the path to a json file,
+ # let's parse it to get our arguments.
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+ else:
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Setup logging
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ handlers=[logging.StreamHandler(sys.stdout)],
+ )
+
+ log_level = training_args.get_process_log_level()
+ logger.setLevel(log_level)
+ datasets.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+
+ # Log on each process the small summary:
+ logger.warning(
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+ )
+ logger.info(f"Training/evaluation parameters {training_args}")
+
+ # Detecting last checkpoint.
+ last_checkpoint = None
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+ raise ValueError(
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+ "Use --overwrite_output_dir to overcome."
+ )
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+ logger.info(
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+ )
+
+ # Set seed before initializing model.
+ set_seed(training_args.seed)
+
+ # Get the datasets
+ # In distributed training, the load_dataset function guarantee that only one local process can concurrently
+ # download the dataset.
+ if data_args.dataset_name == "funsd":
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ "nielsr/funsd-layoutlmv3",
+ data_args.dataset_config_name,
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ elif data_args.dataset_name == "cord":
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ "nielsr/cord-layoutlmv3",
+ data_args.dataset_config_name,
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ else:
+ raise ValueError("This script only supports either FUNSD or CORD out-of-the-box.")
+
+ if training_args.do_train:
+ column_names = dataset["train"].column_names
+ features = dataset["train"].features
+ else:
+ column_names = dataset["test"].column_names
+ features = dataset["test"].features
+
+ image_column_name = "image"
+ text_column_name = "words" if "words" in column_names else "tokens"
+ boxes_column_name = "bboxes"
+ label_column_name = (
+ f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1]
+ )
+
+ remove_columns = column_names
+
+ # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
+ # unique labels.
+ def get_label_list(labels):
+ unique_labels = set()
+ for label in labels:
+ unique_labels = unique_labels | set(label)
+ label_list = list(unique_labels)
+ label_list.sort()
+ return label_list
+
+ # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere.
+ # Otherwise, we have to get the list of labels manually.
+ if isinstance(features[label_column_name].feature, ClassLabel):
+ label_list = features[label_column_name].feature.names
+ # No need to convert the labels since they are already ints.
+ id2label = {k: v for k, v in enumerate(label_list)}
+ label2id = {v: k for k, v in enumerate(label_list)}
+ else:
+ label_list = get_label_list(dataset["train"][label_column_name])
+ id2label = {k: v for k, v in enumerate(label_list)}
+ label2id = {v: k for k, v in enumerate(label_list)}
+ num_labels = len(label_list)
+
+ # Load pretrained model and processor
+ #
+ # Distributed training:
+ # The .from_pretrained methods guarantee that only one local process can concurrently
+ # download model & vocab.
+ config = AutoConfig.from_pretrained(
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+ num_labels=num_labels,
+ finetuning_task=data_args.task_name,
+ cache_dir=model_args.cache_dir,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+
+ processor = AutoProcessor.from_pretrained(
+ model_args.processor_name if model_args.processor_name else model_args.model_name_or_path,
+ cache_dir=model_args.cache_dir,
+ use_fast=True,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ add_prefix_space=True,
+ apply_ocr=False,
+ )
+
+ model = AutoModelForTokenClassification.from_pretrained(
+ model_args.model_name_or_path,
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
+ config=config,
+ cache_dir=model_args.cache_dir,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+
+ # Set the correspondences label/ID inside the model config
+ model.config.label2id = label2id
+ model.config.id2label = id2label
+
+ # Preprocessing the dataset
+ # The processor does everything for us (prepare the image using LayoutLMv3FeatureExtractor
+ # and prepare the words, boxes and word-level labels using LayoutLMv3TokenizerFast)
+ def prepare_examples(examples):
+ images = examples[image_column_name]
+ words = examples[text_column_name]
+ boxes = examples[boxes_column_name]
+ word_labels = examples[label_column_name]
+
+ encoding = processor(
+ images,
+ words,
+ boxes=boxes,
+ word_labels=word_labels,
+ truncation=True,
+ padding="max_length",
+ max_length=data_args.max_seq_length,
+ )
+
+ return encoding
+
+ if training_args.do_train:
+ if "train" not in dataset:
+ raise ValueError("--do_train requires a train dataset")
+ train_dataset = dataset["train"]
+ if data_args.max_train_samples is not None:
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ with training_args.main_process_first(desc="train dataset map pre-processing"):
+ train_dataset = train_dataset.map(
+ prepare_examples,
+ batched=True,
+ remove_columns=remove_columns,
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=not data_args.overwrite_cache,
+ )
+
+ if training_args.do_eval:
+ validation_name = "test"
+ if validation_name not in dataset:
+ raise ValueError("--do_eval requires a validation dataset")
+ eval_dataset = dataset[validation_name]
+ if data_args.max_eval_samples is not None:
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ with training_args.main_process_first(desc="validation dataset map pre-processing"):
+ eval_dataset = eval_dataset.map(
+ prepare_examples,
+ batched=True,
+ remove_columns=remove_columns,
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=not data_args.overwrite_cache,
+ )
+
+ if training_args.do_predict:
+ if "test" not in datasets:
+ raise ValueError("--do_predict requires a test dataset")
+ predict_dataset = datasets["test"]
+ if data_args.max_predict_samples is not None:
+ max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
+ predict_dataset = predict_dataset.select(range(max_predict_samples))
+ with training_args.main_process_first(desc="prediction dataset map pre-processing"):
+ predict_dataset = predict_dataset.map(
+ prepare_examples,
+ batched=True,
+ remove_columns=remove_columns,
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=not data_args.overwrite_cache,
+ )
+
+ # Metrics
+ metric = load_metric("seqeval")
+
+ def compute_metrics(p):
+ predictions, labels = p
+ predictions = np.argmax(predictions, axis=2)
+
+ # Remove ignored index (special tokens)
+ true_predictions = [
+ [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
+ for prediction, label in zip(predictions, labels)
+ ]
+ true_labels = [
+ [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
+ for prediction, label in zip(predictions, labels)
+ ]
+
+ results = metric.compute(predictions=true_predictions, references=true_labels)
+ if data_args.return_entity_level_metrics:
+ # Unpack nested dictionaries
+ final_results = {}
+ for key, value in results.items():
+ if isinstance(value, dict):
+ for n, v in value.items():
+ final_results[f"{key}_{n}"] = v
+ else:
+ final_results[key] = value
+ return final_results
+ else:
+ return {
+ "precision": results["overall_precision"],
+ "recall": results["overall_recall"],
+ "f1": results["overall_f1"],
+ "accuracy": results["overall_accuracy"],
+ }
+
+ # Initialize our Trainer
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_dataset if training_args.do_train else None,
+ eval_dataset=eval_dataset if training_args.do_eval else None,
+ tokenizer=processor,
+ data_collator=default_data_collator,
+ compute_metrics=compute_metrics,
+ )
+
+ # Training
+ if training_args.do_train:
+ checkpoint = None
+ if training_args.resume_from_checkpoint is not None:
+ checkpoint = training_args.resume_from_checkpoint
+ elif last_checkpoint is not None:
+ checkpoint = last_checkpoint
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
+ metrics = train_result.metrics
+ trainer.save_model() # Saves the tokenizer too for easy upload
+
+ max_train_samples = (
+ data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
+ )
+ metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+
+ trainer.log_metrics("train", metrics)
+ trainer.save_metrics("train", metrics)
+ trainer.save_state()
+
+ # Evaluation
+ if training_args.do_eval:
+ logger.info("*** Evaluate ***")
+
+ metrics = trainer.evaluate()
+
+ max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+ metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ # Predict
+ if training_args.do_predict:
+ logger.info("*** Predict ***")
+
+ predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
+ predictions = np.argmax(predictions, axis=2)
+
+ # Remove ignored index (special tokens)
+ true_predictions = [
+ [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
+ for prediction, label in zip(predictions, labels)
+ ]
+
+ trainer.log_metrics("predict", metrics)
+ trainer.save_metrics("predict", metrics)
+
+ # Save predictions
+ output_predictions_file = os.path.join(training_args.output_dir, "predictions.txt")
+ if trainer.is_world_process_zero():
+ with open(output_predictions_file, "w") as writer:
+ for prediction in true_predictions:
+ writer.write(" ".join(prediction) + "\n")
+
+ kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "token-classification"}
+ if data_args.dataset_name is not None:
+ kwargs["dataset_tags"] = data_args.dataset_name
+ if data_args.dataset_config_name is not None:
+ kwargs["dataset_args"] = data_args.dataset_config_name
+ kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
+ else:
+ kwargs["dataset"] = data_args.dataset_name
+
+ if training_args.push_to_hub:
+ trainer.push_to_hub(**kwargs)
+ else:
+ trainer.create_model_card(**kwargs)
+
+
+def _mp_fn(index):
+ # For xla_spawn (TPUs)
+ main()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/research_projects/longform-qa/eli5_utils.py b/examples/research_projects/longform-qa/eli5_utils.py
index c14210bd5e58..82c4bd8caf20 100644
--- a/examples/research_projects/longform-qa/eli5_utils.py
+++ b/examples/research_projects/longform-qa/eli5_utils.py
@@ -649,7 +649,7 @@ def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages,
" " + "
".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
]
all_res_lists = []
- for (res_passages, dl) in zip(res_passages_lst, D):
+ for res_passages, dl in zip(res_passages_lst, D):
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
for r, sc in zip(res_list, dl):
r["score"] = float(sc)
@@ -679,7 +679,7 @@ def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passage
"
" + "
".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
]
all_res_lists = []
- for (res_passages, dl, il) in zip(res_passages_lst, D, I):
+ for res_passages, dl, il in zip(res_passages_lst, D, I):
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
for r, sc, i in zip(res_list, dl, il):
r["passage_id"] = int(i)
diff --git a/examples/research_projects/luke/README.md b/examples/research_projects/luke/README.md
index a4eb1370436b..703eb0b4e423 100644
--- a/examples/research_projects/luke/README.md
+++ b/examples/research_projects/luke/README.md
@@ -14,7 +14,7 @@ the mean of the [🤗 `Accelerate`](https://github.com/huggingface/accelerate) l
after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
then to train English LUKE on CoNLL2003:
diff --git a/examples/research_projects/luke/run_luke_ner_no_trainer.py b/examples/research_projects/luke/run_luke_ner_no_trainer.py
index c7a9763d9965..cb81402425ff 100644
--- a/examples/research_projects/luke/run_luke_ner_no_trainer.py
+++ b/examples/research_projects/luke/run_luke_ner_no_trainer.py
@@ -101,8 +101,8 @@ def parse_args():
type=int,
default=32,
help=(
- "The maximum total input entity length after tokenization (Used only for (M)Luke models). Sequences longer than this will be truncated,"
- " sequences shorter will be padded if `--pad_to_max_length` is passed."
+ "The maximum total input entity length after tokenization (Used only for (M)Luke models). Sequences longer"
+ " than this will be truncated, sequences shorter will be padded if `--pad_to_max_length` is passed."
),
)
parser.add_argument(
@@ -110,8 +110,8 @@ def parse_args():
type=int,
default=30,
help=(
- "The maximum total input mention length after tokenization (Used only for (M)Luke models). Sequences longer than this will be truncated,"
- " sequences shorter will be padded if `--pad_to_max_length` is passed."
+ "The maximum total input mention length after tokenization (Used only for (M)Luke models). Sequences"
+ " longer than this will be truncated, sequences shorter will be padded if `--pad_to_max_length` is passed."
),
)
parser.add_argument(
diff --git a/examples/research_projects/lxmert/demo.ipynb b/examples/research_projects/lxmert/demo.ipynb
index 55658ae111e6..e80865d0e2c8 100644
--- a/examples/research_projects/lxmert/demo.ipynb
+++ b/examples/research_projects/lxmert/demo.ipynb
@@ -6,7 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
- "#%pip install-r requirements.txt"
+ "# %pip install-r requirements.txt"
]
},
{
diff --git a/examples/research_projects/lxmert/modeling_frcnn.py b/examples/research_projects/lxmert/modeling_frcnn.py
index 39a0c6aea878..33c1133e9589 100644
--- a/examples/research_projects/lxmert/modeling_frcnn.py
+++ b/examples/research_projects/lxmert/modeling_frcnn.py
@@ -592,7 +592,7 @@ def __call__(self, match_quality_matrix):
match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
- for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
+ for l, low, high in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
low_high = (matched_vals >= low) & (matched_vals < high)
match_labels[low_high] = l
@@ -1037,9 +1037,9 @@ def make_stage(
curr_kwargs = {}
for k, v in kwargs.items():
if k.endswith("_per_block"):
- assert len(v) == num_blocks, (
- f"Argument '{k}' of make_stage should have the " f"same length as num_blocks={num_blocks}."
- )
+ assert (
+ len(v) == num_blocks
+ ), f"Argument '{k}' of make_stage should have the same length as num_blocks={num_blocks}."
newk = k[: -len("_per_block")]
assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
curr_kwargs[newk] = v[i]
@@ -1401,7 +1401,7 @@ def num_cell_anchors(self):
def grid_anchors(self, grid_sizes):
anchors = []
- for (size, stride, base_anchors) in zip(grid_sizes, self.strides, self.cell_anchors):
+ for size, stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors):
shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors.device)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
@@ -1708,10 +1708,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
- assert (
- from_tf
- ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
- pretrained_model_name_or_path + ".index"
+ assert from_tf, (
+ "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint"
+ .format(pretrained_model_name_or_path + ".index")
)
archive_file = pretrained_model_name_or_path + ".index"
else:
@@ -1797,26 +1796,28 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if len(unexpected_keys) > 0:
print(
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
- f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
- f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
- f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
- f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
- f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+ " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
+ " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
else:
print(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
print(
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
- f"and are newly initialized: {missing_keys}\n"
- f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
print(
- f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
- f"If your task is similar to the task the model of the checkpoint was trained on, "
- f"you can already use {model.__class__.__name__} for predictions without further training."
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+ " training."
)
if len(error_msgs) > 0:
raise RuntimeError(
diff --git a/examples/research_projects/lxmert/requirements.txt b/examples/research_projects/lxmert/requirements.txt
index fc3b85e16541..28a15ccb6ada 100644
--- a/examples/research_projects/lxmert/requirements.txt
+++ b/examples/research_projects/lxmert/requirements.txt
@@ -40,14 +40,14 @@ kiwisolver==1.2.0
lockfile==0.12.2
MarkupSafe==1.1.1
matplotlib==3.3.1
-mistune==0.8.4
+mistune==2.0.3
msgpack==0.6.2
nbclient==0.5.0
nbconvert==6.0.1
nbformat==5.0.7
nest-asyncio==1.4.0
-notebook==6.4.10
-numpy==1.21.0
+notebook==6.4.12
+numpy==1.22.0
opencv-python==4.4.0.42
packaging==20.3
pandas==1.1.2
diff --git a/examples/research_projects/lxmert/utils.py b/examples/research_projects/lxmert/utils.py
index 59ae11d025ad..8e830fb8359d 100644
--- a/examples/research_projects/lxmert/utils.py
+++ b/examples/research_projects/lxmert/utils.py
@@ -231,9 +231,10 @@ def compare(in_tensor):
n2 = out_tensor.numpy()[0]
print(n1.shape, n1[0, 0, :5])
print(n2.shape, n2[0, 0, :5])
- assert np.allclose(
- n1, n2, rtol=0.01, atol=0.1
- ), f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} % element-wise mismatch"
+ assert np.allclose(n1, n2, rtol=0.01, atol=0.1), (
+ f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} %"
+ " element-wise mismatch"
+ )
raise Exception("tensors are all good")
# Hugging face functions below
diff --git a/examples/research_projects/mlm_wwm/run_mlm_wwm.py b/examples/research_projects/mlm_wwm/run_mlm_wwm.py
index 51c05ab0b3de..f14ad5adfeff 100644
--- a/examples/research_projects/mlm_wwm/run_mlm_wwm.py
+++ b/examples/research_projects/mlm_wwm/run_mlm_wwm.py
@@ -61,8 +61,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -72,8 +73,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
config_name: Optional[str] = field(
@@ -97,8 +100,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -146,8 +151,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated. Default to the max input length of the model."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated. Default to the max input length of the model."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -160,8 +167,10 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
diff --git a/examples/research_projects/mm-imdb/run_mmimdb.py b/examples/research_projects/mm-imdb/run_mmimdb.py
index c73aec5c8747..9f12257a10a8 100644
--- a/examples/research_projects/mm-imdb/run_mmimdb.py
+++ b/examples/research_projects/mm-imdb/run_mmimdb.py
@@ -356,8 +356,10 @@ def main():
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--num_image_embeds", default=1, type=int, help="Number of Image Embeddings from the Image Encoder"
@@ -423,8 +425,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
diff --git a/examples/research_projects/movement-pruning/bertarize.py b/examples/research_projects/movement-pruning/bertarize.py
index d1e2462a3044..623b46b94386 100644
--- a/examples/research_projects/movement-pruning/bertarize.py
+++ b/examples/research_projects/movement-pruning/bertarize.py
@@ -103,15 +103,20 @@ def main(args):
choices=["l0", "magnitude", "topK", "sigmoied_threshold"],
type=str,
required=True,
- help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
+ help=(
+ "Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning,"
+ " sigmoied_threshold = Soft movement pruning)"
+ ),
)
parser.add_argument(
"--threshold",
type=float,
required=False,
- help="For `magnitude` and `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
- "For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
- "Not needed for `l0`",
+ help=(
+ "For `magnitude` and `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
+ "For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
+ "Not needed for `l0`"
+ ),
)
parser.add_argument(
"--model_name_or_path",
diff --git a/examples/research_projects/movement-pruning/counts_parameters.py b/examples/research_projects/movement-pruning/counts_parameters.py
index 0dddfaaa277d..0aec3766b3f9 100644
--- a/examples/research_projects/movement-pruning/counts_parameters.py
+++ b/examples/research_projects/movement-pruning/counts_parameters.py
@@ -70,15 +70,20 @@ def main(args):
choices=["l0", "topK", "sigmoied_threshold"],
type=str,
required=True,
- help="Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
+ help=(
+ "Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement"
+ " pruning)"
+ ),
)
parser.add_argument(
"--threshold",
type=float,
required=False,
- help="For `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
- "For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
- "Not needed for `l0`",
+ help=(
+ "For `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
+ "For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
+ "Not needed for `l0`"
+ ),
)
parser.add_argument(
"--serialization_dir",
diff --git a/examples/research_projects/movement-pruning/emmental/modeling_bert_masked.py b/examples/research_projects/movement-pruning/emmental/modeling_bert_masked.py
index 771d2078d066..4228050fe123 100644
--- a/examples/research_projects/movement-pruning/emmental/modeling_bert_masked.py
+++ b/examples/research_projects/movement-pruning/emmental/modeling_bert_masked.py
@@ -80,8 +80,8 @@ def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
- "The hidden size (%d) is not a multiple of the number of attention "
- "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
+ % (config.hidden_size, config.num_attention_heads)
)
self.output_attentions = config.output_attentions
diff --git a/examples/research_projects/movement-pruning/masked_run_glue.py b/examples/research_projects/movement-pruning/masked_run_glue.py
index 57f795945b1e..e81cf9209c88 100644
--- a/examples/research_projects/movement-pruning/masked_run_glue.py
+++ b/examples/research_projects/movement-pruning/masked_run_glue.py
@@ -622,8 +622,10 @@ def main():
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -669,22 +671,29 @@ def main():
"--initial_warmup",
default=1,
type=int,
- help="Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
- "at its `initial_threshold` value (sparsity schedule).",
+ help=(
+ "Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
+ "at its `initial_threshold` value (sparsity schedule)."
+ ),
)
parser.add_argument(
"--final_warmup",
default=2,
type=int,
- help="Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
- "at its final_threshold value (sparsity schedule).",
+ help=(
+ "Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
+ "at its final_threshold value (sparsity schedule)."
+ ),
)
parser.add_argument(
"--pruning_method",
default="topK",
type=str,
- help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning).",
+ help=(
+ "Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning,"
+ " sigmoied_threshold = Soft movement pruning)."
+ ),
)
parser.add_argument(
"--mask_init",
@@ -717,7 +726,10 @@ def main():
"--teacher_type",
default=None,
type=str,
- help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
+ help=(
+ "Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for"
+ " distillation."
+ ),
)
parser.add_argument(
"--teacher_name_or_path",
@@ -787,8 +799,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -805,7 +819,8 @@ def main():
and not args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to"
+ " overcome."
)
# Setup CUDA, GPU & distributed training
diff --git a/examples/research_projects/movement-pruning/masked_run_squad.py b/examples/research_projects/movement-pruning/masked_run_squad.py
index f1d065f1f46b..1bd501eda514 100644
--- a/examples/research_projects/movement-pruning/masked_run_squad.py
+++ b/examples/research_projects/movement-pruning/masked_run_squad.py
@@ -737,8 +737,10 @@ def main():
"--max_seq_length",
default=384,
type=int,
- help="The maximum total input sequence length after WordPiece tokenization. Sequences "
- "longer than this will be truncated, and sequences shorter than this will be padded.",
+ help=(
+ "The maximum total input sequence length after WordPiece tokenization. Sequences "
+ "longer than this will be truncated, and sequences shorter than this will be padded."
+ ),
)
parser.add_argument(
"--doc_stride",
@@ -750,8 +752,10 @@ def main():
"--max_query_length",
default=64,
type=int,
- help="The maximum number of tokens for the question. Questions longer than this will "
- "be truncated to this length.",
+ help=(
+ "The maximum number of tokens for the question. Questions longer than this will "
+ "be truncated to this length."
+ ),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -785,22 +789,29 @@ def main():
"--initial_warmup",
default=1,
type=int,
- help="Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
- "at its `initial_threshold` value (sparsity schedule).",
+ help=(
+ "Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
+ "at its `initial_threshold` value (sparsity schedule)."
+ ),
)
parser.add_argument(
"--final_warmup",
default=2,
type=int,
- help="Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
- "at its final_threshold value (sparsity schedule).",
+ help=(
+ "Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
+ "at its final_threshold value (sparsity schedule)."
+ ),
)
parser.add_argument(
"--pruning_method",
default="topK",
type=str,
- help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning).",
+ help=(
+ "Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning,"
+ " sigmoied_threshold = Soft movement pruning)."
+ ),
)
parser.add_argument(
"--mask_init",
@@ -833,7 +844,10 @@ def main():
"--teacher_type",
default=None,
type=str,
- help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
+ help=(
+ "Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for"
+ " distillation."
+ ),
)
parser.add_argument(
"--teacher_name_or_path",
@@ -883,20 +897,27 @@ def main():
"--max_answer_length",
default=30,
type=int,
- help="The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another.",
+ help=(
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ ),
)
parser.add_argument(
"--verbose_logging",
action="store_true",
- help="If true, all of the warnings related to data processing will be printed. "
- "A number of warnings are expected for a normal SQuAD evaluation.",
+ help=(
+ "If true, all of the warnings related to data processing will be printed. "
+ "A number of warnings are expected for a normal SQuAD evaluation."
+ ),
)
parser.add_argument(
"--lang_id",
default=0,
type=int,
- help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)",
+ help=(
+ "language id of input for language-specific xlm models (see"
+ " tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
+ ),
)
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
@@ -925,8 +946,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
diff --git a/examples/research_projects/onnx/summarization/bart_onnx/generation_onnx.py b/examples/research_projects/onnx/summarization/bart_onnx/generation_onnx.py
index 58ee49a1b680..6db6842968a5 100644
--- a/examples/research_projects/onnx/summarization/bart_onnx/generation_onnx.py
+++ b/examples/research_projects/onnx/summarization/bart_onnx/generation_onnx.py
@@ -392,13 +392,14 @@ def init(
if not isinstance(num_beams, int) or num_beams <= 1:
raise ValueError(
- f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
+ f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
+ " one should make use of `greedy_search` instead."
)
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
raise ValueError(
- f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` "
- f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
+ "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
+ f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
)
def hypo_len(self, hypo_idx: int):
@@ -508,7 +509,8 @@ def process(
if beam_idx < self.group_size:
raise ValueError(
- f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
+ f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
+ f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
)
# Check if we are done so that we can save a pad step if all(done)
diff --git a/examples/research_projects/onnx/summarization/run_onnx_exporter.py b/examples/research_projects/onnx/summarization/run_onnx_exporter.py
index 2a62ca9f704d..5d751ace8eee 100644
--- a/examples/research_projects/onnx/summarization/run_onnx_exporter.py
+++ b/examples/research_projects/onnx/summarization/run_onnx_exporter.py
@@ -53,14 +53,16 @@ def parse_args():
"--max_length",
type=int,
default=5,
- help=("The maximum total input sequence length after tokenization."),
+ help="The maximum total input sequence length after tokenization.",
)
parser.add_argument(
"--num_beams",
type=int,
default=None,
- help="Number of beams to use for evaluation. This argument will be "
- "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+ help=(
+ "Number of beams to use for evaluation. This argument will be "
+ "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``."
+ ),
)
parser.add_argument(
"--model_name_or_path",
diff --git a/examples/research_projects/performer/modeling_flax_performer_utils.py b/examples/research_projects/performer/modeling_flax_performer_utils.py
index abd42ec3d986..915e2fa23dd9 100644
--- a/examples/research_projects/performer/modeling_flax_performer_utils.py
+++ b/examples/research_projects/performer/modeling_flax_performer_utils.py
@@ -535,7 +535,7 @@ def dot_product_attention(
assert key.ndim == value.ndim
for ax in axis:
if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
- raise ValueError("Attention axis must be between the batch " "axis and the last-two axes.")
+ raise ValueError("Attention axis must be between the batch axis and the last-two axes.")
n = key.ndim
# Constructing projection tensor.
diff --git a/examples/research_projects/performer/run_mlm_performer.py b/examples/research_projects/performer/run_mlm_performer.py
index 34aa75f8a9d6..8e8fe917653e 100644
--- a/examples/research_projects/performer/run_mlm_performer.py
+++ b/examples/research_projects/performer/run_mlm_performer.py
@@ -98,8 +98,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
performer: bool = field(
@@ -159,8 +160,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated. Default to the max input length of the model."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated. Default to the max input length of the model."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -173,8 +176,10 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
@@ -428,7 +433,7 @@ def eval_step(params, batch):
return compute_metrics(logits, targets, token_mask)
-def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
nb_samples = len(samples_idx)
samples_to_remove = nb_samples % batch_size
@@ -634,7 +639,8 @@ def tokenize_function(examples):
# Generate an epoch by shuffling sampling indices from the train dataset
nb_training_samples = len(tokenized_datasets["train"])
- training_samples_idx = jax.random.permutation(training_rng, jnp.arange(nb_training_samples))
+ # Avoid using jax.numpy here in case of TPU training
+ training_samples_idx = np.random.permutation(np.arange(nb_training_samples))
training_batch_idx = generate_batch_splits(training_samples_idx, batch_size)
# Gather the indexes for creating the batch and do a training step
@@ -653,7 +659,8 @@ def tokenize_function(examples):
# ======================== Evaluating ==============================
nb_eval_samples = len(tokenized_datasets["validation"])
- eval_samples_idx = jnp.arange(nb_eval_samples)
+ # Avoid using jax.numpy here in case of TPU training
+ eval_samples_idx = np.arange(nb_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = []
diff --git a/examples/research_projects/pplm/run_pplm_discrim_train.py b/examples/research_projects/pplm/run_pplm_discrim_train.py
index ec8cd9b9facd..6a7351d9e6a6 100644
--- a/examples/research_projects/pplm/run_pplm_discrim_train.py
+++ b/examples/research_projects/pplm/run_pplm_discrim_train.py
@@ -175,8 +175,7 @@ def evaluate_performance(data_loader, discriminator, device="cpu"):
test_loss /= len(data_loader.dataset)
print(
- "Performance on test set: "
- "Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
+ "Performance on test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
test_loss, correct, len(data_loader.dataset), 100.0 * correct / len(data_loader.dataset)
)
)
@@ -309,7 +308,7 @@ def train_discriminator(
x.append(seq)
y.append(d["label"])
except Exception:
- print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
+ print("Error evaluating / tokenizing line {}, skipping it".format(i))
pass
full_dataset = Dataset(x, y)
@@ -349,7 +348,7 @@ def train_discriminator(
x.append(seq)
y.append(int(np.sum(d["label"]) > 0))
except Exception:
- print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
+ print("Error evaluating / tokenizing line {}, skipping it".format(i))
pass
full_dataset = Dataset(x, y)
@@ -370,7 +369,7 @@ def train_discriminator(
# class \t text
if dataset_fp is None:
- raise ValueError("When generic dataset is selected, " "dataset_fp needs to be specified aswell.")
+ raise ValueError("When generic dataset is selected, dataset_fp needs to be specified aswell.")
classes = set()
with open(dataset_fp) as f:
@@ -490,15 +489,17 @@ def train_discriminator(
type=str,
default="SST",
choices=("SST", "clickbait", "toxic", "generic"),
- help="dataset to train the discriminator on."
- "In case of generic, the dataset is expected"
- "to be a TSBV file with structure: class \\t text",
+ help=(
+ "dataset to train the discriminator on."
+ "In case of generic, the dataset is expected"
+ "to be a TSBV file with structure: class \\t text"
+ ),
)
parser.add_argument(
"--dataset_fp",
type=str,
default="",
- help="File path of the dataset to use. " "Needed only in case of generic datadset",
+ help="File path of the dataset to use. Needed only in case of generic datadset",
)
parser.add_argument(
"--pretrained_model", type=str, default="gpt2-medium", help="Pretrained model to use as encoder"
diff --git a/examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py b/examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py
index 4a618ed77cd5..2a0899630395 100755
--- a/examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py
+++ b/examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py
@@ -87,8 +87,10 @@
"--max_seq_length",
default=384,
type=int,
- help="The maximum total input sequence length after WordPiece tokenization. Sequences "
- "longer than this will be truncated, and sequences shorter than this will be padded.",
+ help=(
+ "The maximum total input sequence length after WordPiece tokenization. Sequences "
+ "longer than this will be truncated, and sequences shorter than this will be padded."
+ ),
)
parser.add_argument(
"--doc_stride",
@@ -109,8 +111,10 @@
"--max_answer_length",
default=30,
type=int,
- help="The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another.",
+ help=(
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ ),
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
diff --git a/examples/research_projects/quantization-qdqbert/quant_trainer.py b/examples/research_projects/quantization-qdqbert/quant_trainer.py
index b9fbad8a4a82..ce1ecb6c51fe 100755
--- a/examples/research_projects/quantization-qdqbert/quant_trainer.py
+++ b/examples/research_projects/quantization-qdqbert/quant_trainer.py
@@ -51,8 +51,10 @@ def add_arguments(parser):
group.add_argument(
"--recalibrate-weights",
action="store_true",
- help="recalibrate weight amaxes by taking the max of the weights."
- " amaxes will be computed with the current quantization granularity (axis).",
+ help=(
+ "recalibrate weight amaxes by taking the max of the weights."
+ " amaxes will be computed with the current quantization granularity (axis)."
+ ),
)
diff --git a/examples/research_projects/quantization-qdqbert/run_quant_qa.py b/examples/research_projects/quantization-qdqbert/run_quant_qa.py
index 36bfb45c8ffc..5008197b8b84 100755
--- a/examples/research_projects/quantization-qdqbert/run_quant_qa.py
+++ b/examples/research_projects/quantization-qdqbert/run_quant_qa.py
@@ -83,8 +83,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
do_calib: bool = field(default=False, metadata={"help": "Whether to run calibration of quantization ranges."})
@@ -126,37 +128,46 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=384,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=True,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
- "be faster on GPU but will be slower on TPU)."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
+ " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
version_2_with_negative: bool = field(
@@ -165,9 +176,11 @@ class DataTrainingArguments:
null_score_diff_threshold: float = field(
default=0.0,
metadata={
- "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`."
+ "help": (
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ )
},
)
doc_stride: int = field(
@@ -181,8 +194,10 @@ class DataTrainingArguments:
max_answer_length: int = field(
default=30,
metadata={
- "help": "The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another."
+ "help": (
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ )
},
)
@@ -328,9 +343,9 @@ def main():
# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
- "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
- "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
- "requirement"
+ "This example script only works for models that have a fast tokenizer. Checkout the big table of models at"
+ " https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet"
+ " this requirement"
)
# Preprocessing the datasets.
diff --git a/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py b/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py
index b23edb6d5185..ef0d93a7e357 100644
--- a/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py
+++ b/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py
@@ -30,7 +30,7 @@
logger = logging.getLogger(__name__)
-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
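
`check_device=False` makes `is_torch_tpu_available` report only whether `torch_xla` is importable instead of probing for an actual XLA device, so the module can be imported on GPU-only hosts. A minimal sketch of the guard, assuming a `transformers` release recent enough to accept the argument:

from transformers import is_torch_tpu_available

if is_torch_tpu_available(check_device=False):
    # Only reached when torch_xla is installed; no XLA device is acquired here.
    import torch_xla.core.xla_model as xm  # noqa: F401
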
diff --git a/examples/research_projects/rag-end2end-retriever/README.md b/examples/research_projects/rag-end2end-retriever/README.md
index 7cee2f1ea09c..9bff4e8c29ab 100644
--- a/examples/research_projects/rag-end2end-retriever/README.md
+++ b/examples/research_projects/rag-end2end-retriever/README.md
@@ -15,6 +15,10 @@ This code can be modified to experiment with other research on retrival augmente
To start training, use the bash script (finetune_rag_ray_end2end.sh) in this folder. This script also includes descriptions on each command-line argument used.
+# Latest Update
+
+⚠️ Updated the rag-end2end-retriever to be compatible with PL==1.6.4 and RAY==1.13.0 (the latest versions as of 2022-June-11)
+
# Note
⚠️ This project should be run with pytorch-lightning==1.3.1 which has a potential security vulnerability
@@ -22,12 +26,14 @@ To start training, use the bash script (finetune_rag_ray_end2end.sh) in this fol
# Testing
The following two bash scripts can be used to quickly test the implementation.
-1. sh ./test_run/test_rag_new_features.sh
- - Tests the newly added functions (set_context_encoder and set_context_encoder_tokenizer) related to modeling rag.
- - This is sufficient to check the model's ability to use the set functions correctly.
-2. sh ./test_run/test_finetune.sh script
+1. sh ./test_run/test_finetune.sh script
 - Tests the full end-to-end fine-tuning ability with a dummy knowledge-base and dummy training dataset (check test_dir directory).
- Users can replace the dummy dataset and knowledge-base with their own to do their own finetuning.
+ - Please read the comments in the test_finetune.sh file.
+2. sh ./test_run/test_rag_new_features.sh
+ - Tests the newly added functions (set_context_encoder and set_context_encoder_tokenizer) related to modeling rag.
+ - This is sufficient to check the model's ability to use the set functions correctly.
+
# Comparison of end2end RAG (including DPR finetuning) VS original-RAG
diff --git a/examples/research_projects/rag-end2end-retriever/callbacks_rag.py b/examples/research_projects/rag-end2end-retriever/callbacks_rag.py
index 55fc9655dff7..5f18244a7aa4 100644
--- a/examples/research_projects/rag-end2end-retriever/callbacks_rag.py
+++ b/examples/research_projects/rag-end2end-retriever/callbacks_rag.py
@@ -31,7 +31,8 @@ def get_checkpoint_callback(output_dir, metric):
exp = "{val_avg_loss:.4f}-{step_count}"
else:
raise NotImplementedError(
- f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
+ f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this"
+ " function."
)
checkpoint_callback = ModelCheckpoint(
@@ -40,7 +41,7 @@ def get_checkpoint_callback(output_dir, metric):
monitor=f"val_{metric}",
mode="max",
save_top_k=1,
- every_n_val_epochs=1, # works only with PL > 1.3
+ every_n_epochs=1, # works only with PL > 1.3
)
return checkpoint_callback
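
`every_n_val_epochs` was deprecated and then removed from pytorch-lightning's `ModelCheckpoint`; `every_n_epochs` is its replacement under the PL==1.6.4 pin introduced by this patch. A minimal sketch with illustrative directory, filename, and metric names:

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="output",                          # illustrative output directory
    filename="{val_rouge2:.4f}-{step_count}",
    monitor="val_rouge2",
    mode="max",
    save_top_k=1,
    every_n_epochs=1,                          # replaces the removed every_n_val_epochs
)
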
diff --git a/examples/research_projects/rag-end2end-retriever/eval_rag.py b/examples/research_projects/rag-end2end-retriever/eval_rag.py
index 05f78c3d6cdf..a8e7abbca6ce 100644
--- a/examples/research_projects/rag-end2end-retriever/eval_rag.py
+++ b/examples/research_projects/rag-end2end-retriever/eval_rag.py
@@ -146,7 +146,10 @@ def get_args():
"--model_type",
choices=["rag_sequence", "rag_token", "bart"],
type=str,
- help="RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the model_name_or_path",
+ help=(
+ "RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the"
+ " model_name_or_path"
+ ),
)
parser.add_argument(
"--index_name",
@@ -174,7 +177,10 @@ def get_args():
choices=["e2e", "retrieval"],
default="e2e",
type=str,
- help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates precision@k.",
+ help=(
+ "Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates"
+ " precision@k."
+ ),
)
parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
parser.add_argument(
@@ -196,9 +202,11 @@ def get_args():
default="qa",
type=str,
choices=["qa", "ans"],
- help="Format of the gold data file"
- "qa - a single line in the following format: question [tab] answer_list"
- "ans - a single line of the gold file contains the expected answer string",
+ help=(
+ "Format of the gold data file"
+ "qa - a single line in the following format: question [tab] answer_list"
+ "ans - a single line of the gold file contains the expected answer string"
+ ),
)
parser.add_argument(
"--predictions_path",
diff --git a/examples/research_projects/rag-end2end-retriever/finetune_rag.py b/examples/research_projects/rag-end2end-retriever/finetune_rag.py
index 96cbc0f7c530..1229870e63c6 100644
--- a/examples/research_projects/rag-end2end-retriever/finetune_rag.py
+++ b/examples/research_projects/rag-end2end-retriever/finetune_rag.py
@@ -350,6 +350,7 @@ def training_step(self, batch, batch_idx) -> Dict:
concat.save_to_disk(self.config.passages_path) # here we update the main passage file on the disk
logger.info("done updating the dataset")
+ # To Do (@Aaron) : Useful in the future dynamic memory implementation.
# if you load the index from the disk make sure to update the index file here, otherwise it is ok to update the index file from the worker.
# logger.info("then updating the index")
# shutil.copy(self.custom_config.temp_index, self.config.idex_path)
@@ -360,10 +361,7 @@ def training_step(self, batch, batch_idx) -> Dict:
isEmUpdateBusy = False
isAddIndexBusy = False
-
- self.trainer.accelerator_connector.accelerator.barrier(
- "barrier"
- ) # waint untill the index and kb get re-initialized.
+ self.trainer.strategy.barrier("barrier")
loss_tensors = self._step(batch)
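
In pytorch-lightning >= 1.6 the process barrier is reached through `trainer.strategy` rather than the removed `accelerator_connector.accelerator` path. A minimal sketch of where the call sits, using an illustrative stub instead of the real `GenerativeQAModule`:

import pytorch_lightning as pl

class BarrierDemoModule(pl.LightningModule):  # illustrative stub, not the module in this script
    def training_step(self, batch, batch_idx):
        # Wait until every worker has finished re-building the index / knowledge base
        # before computing the loss for this step.
        self.trainer.strategy.barrier("barrier")
        ...
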
@@ -515,29 +513,37 @@ def add_model_specific_args(parser, root_dir):
"--max_source_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--max_target_length",
default=25,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--val_max_target_length",
default=25,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--test_max_target_length",
default=25,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
@@ -555,7 +561,10 @@ def add_model_specific_args(parser, root_dir):
type=int,
default=-1,
required=False,
- help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
+ help=(
+ "-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So"
+ " val_check_interval will effect it."
+ ),
)
parser.add_argument(
"--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
@@ -564,7 +573,10 @@ def add_model_specific_args(parser, root_dir):
"--model_type",
choices=["rag_sequence", "rag_token", "bart", "t5"],
type=str,
- help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
+ help=(
+ "RAG model type: sequence or token, if none specified, the type is inferred from the"
+ " model_name_or_path"
+ ),
)
parser.add_argument(
"--context_encoder_name",
@@ -590,7 +602,10 @@ def add_model_specific_args(parser, root_dir):
parser.add_argument(
"--gpu_order",
type=str,
- help="order of the GPU used during the fine-tuning. Used to finding free GPUs during the re-encode process. I do not have many GPUs :)",
+ help=(
+ "order of the GPU used during the fine-tuning. Used to finding free GPUs during the re-encode"
+ " process. I do not have many GPUs :)"
+ ),
)
parser.add_argument("--indexing_freq", type=int, help="frequency of re-encode process")
@@ -602,39 +617,53 @@ def add_retriever_specific_args(parser):
"--index_name",
type=str,
default=None,
- help="Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom' for a local index, or 'legacy' for the orignal one)",
+ help=(
+ "Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom'"
+ " for a local index, or 'legacy' for the orignal one)"
+ ),
)
parser.add_argument(
"--passages_path",
type=str,
default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset"),
- help="Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+ help=(
+ "Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever"
+ " documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
+ ),
)
parser.add_argument(
"--index_path",
type=str,
default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset_hnsw_index.faiss"),
- help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+ help=(
+ "Path to the faiss index for custom index. More info about custom indexes in the RagRetriever"
+ " documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
+ ),
)
parser.add_argument(
"--distributed_retriever",
choices=["ray", "pytorch"],
type=str,
default="ray",
- help="What implementation to use for distributed retriever? If "
- "pytorch is selected, the index is loaded on training "
- "worker 0, and torch.distributed is used to handle "
- "communication between training worker 0, and the other "
- "training workers. If ray is selected, the Ray library is "
- "used to create load the index on separate processes, "
- "and Ray handles the communication between the training "
- "workers and the retrieval actors.",
+ help=(
+ "What implementation to use for distributed retriever? If "
+ "pytorch is selected, the index is loaded on training "
+ "worker 0, and torch.distributed is used to handle "
+ "communication between training worker 0, and the other "
+ "training workers. If ray is selected, the Ray library is "
+ "used to create load the index on separate processes, "
+ "and Ray handles the communication between the training "
+ "workers and the retrieval actors."
+ ),
)
parser.add_argument(
"--use_dummy_dataset",
type=bool,
default=False,
- help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+ help=(
+ "Whether to use the dummy version of the dataset index. More info about custom indexes in the"
+ " RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
+ ),
)
return parser
@@ -645,18 +674,22 @@ def add_ray_specific_args(parser):
"--ray-address",
default="auto",
type=str,
- help="The address of the Ray cluster to connect to. If not "
- "specified, Ray will attempt to automatically detect the "
- "cluster. Has no effect if pytorch is used as the distributed "
- "retriever.",
+ help=(
+ "The address of the Ray cluster to connect to. If not "
+ "specified, Ray will attempt to automatically detect the "
+ "cluster. Has no effect if pytorch is used as the distributed "
+ "retriever."
+ ),
)
parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
- help="The number of retrieval actors to use when Ray is selected"
- "for the distributed retriever. Has no effect when "
- "distributed_retriever is set to pytorch.",
+ help=(
+ "The number of retrieval actors to use when Ray is selected"
+ "for the distributed retriever. Has no effect when "
+ "distributed_retriever is set to pytorch."
+ ),
)
return parser
@@ -686,10 +719,10 @@ def main(args=None, model=None) -> GenerativeQAModule:
named_actors = []
if args.distributed_retriever == "ray" and args.gpus > 1:
if not is_ray_available():
- raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
+ raise RuntimeError("Please install Ray to use the Ray distributed retriever.")
# Connect to an existing Ray cluster.
try:
- ray.init(address=args.ray_address)
+ ray.init(address=args.ray_address, namespace="rag")
except (ConnectionError, ValueError):
logger.warning(
"Connection to Ray cluster failed. Make sure a Ray"
diff --git a/examples/research_projects/rag-end2end-retriever/lightning_base.py b/examples/research_projects/rag-end2end-retriever/lightning_base.py
index 1df0fae58498..84842944059a 100644
--- a/examples/research_projects/rag-end2end-retriever/lightning_base.py
+++ b/examples/research_projects/rag-end2end-retriever/lightning_base.py
@@ -5,7 +5,6 @@
from typing import Any, Dict
import pytorch_lightning as pl
-from pytorch_lightning.plugins.training_type import DDPPlugin
from pytorch_lightning.utilities import rank_zero_info
from transformers import (
@@ -333,8 +332,10 @@ def add_generic_args(parser, root_dir) -> None:
"--fp16_opt_level",
type=str,
default="O2",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
@@ -384,24 +385,22 @@ def generic_train(
train_params = {}
- # TODO: remove with PyTorch 1.6 since pl uses native amp
if args.fp16:
train_params["precision"] = 16
- train_params["amp_level"] = args.fp16_opt_level
if args.gpus > 1:
- train_params["accelerator"] = "ddp"
+ train_params["accelerator"] = "auto"
+ train_params["strategy"] = "ddp"
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
- # train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
- train_params["profiler"] = None # extra_train_kwargs.get("profiler", None)
+ train_params["profiler"] = None
+ train_params["devices"] = "auto"
trainer = pl.Trainer.from_argparse_args(
args,
weights_summary=None,
callbacks=[logging_callback] + extra_callbacks + [InitCallback()] + [checkpoint_callback],
logger=logger,
- plugins=[DDPPlugin(find_unused_parameters=True)], # this is needed in new pytorch-lightning new version
val_check_interval=1,
num_sanity_val_steps=2,
**train_params,
@@ -410,6 +409,6 @@ def generic_train(
if args.do_train:
trainer.fit(model)
- # else:
- # print("RAG modeling tests with new set functions successfuly executed!")
+ else:
+ print("RAG modeling tests with new set functions successfuly executed!")
return trainer
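
The Trainer construction above moves from the PL 1.3-era `accelerator="ddp"` plus `DDPPlugin` to the 1.6-style `accelerator`/`devices`/`strategy` arguments. A minimal sketch of the same selection logic, assuming the visible GPU count decides whether DDP is requested:

import pytorch_lightning as pl
import torch

train_params = {"accelerator": "auto", "devices": "auto"}
if torch.cuda.device_count() > 1:
    # Multi-GPU runs ask for DDP explicitly; single-device runs skip the strategy flag.
    train_params["strategy"] = "ddp"

trainer = pl.Trainer(
    max_epochs=1,
    num_sanity_val_steps=2,
    **train_params,
)
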
diff --git a/examples/research_projects/rag-end2end-retriever/requirements.txt b/examples/research_projects/rag-end2end-retriever/requirements.txt
index aca89c78e88c..32025229d074 100644
--- a/examples/research_projects/rag-end2end-retriever/requirements.txt
+++ b/examples/research_projects/rag-end2end-retriever/requirements.txt
@@ -1,7 +1,7 @@
-faiss-cpu >= 1.7.0
-datasets >= 1.6.2
-psutil >= 5.7.0
-torch >= 1.4.0
-pytorch-lightning
+faiss-cpu >= 1.7.2
+datasets
+psutil >= 5.9.1
+torch >= 1.11.0
+pytorch-lightning == 1.6.4
nvidia-ml-py3 == 7.352.0
-ray >= 1.3.0
+ray >= 1.13.0
\ No newline at end of file
diff --git a/examples/research_projects/rag-end2end-retriever/test_run/test_finetune.sh b/examples/research_projects/rag-end2end-retriever/test_run/test_finetune.sh
index bbf69b05380e..c44d110d2004 100755
--- a/examples/research_projects/rag-end2end-retriever/test_run/test_finetune.sh
+++ b/examples/research_projects/rag-end2end-retriever/test_run/test_finetune.sh
@@ -44,11 +44,14 @@ python finetune_rag.py \
--num_retrieval_workers 4 \
--index_name custom \
--context_encoder_name facebook/dpr-ctx_encoder-multiset-base \
- --index_gpus 1 \
- --gpu_order [6,7,8,9,0,1,2,3,5,4] \
+ --index_gpus 2 \
+ --gpu_order [2,3,4,5,6,7,8,9,0,1] \
--indexing_freq 5
# Stop the Ray cluster.
ray stop
+
+#CUDA_VISIBLE_DEVICES=2,3,4,5,6,7,8,9,0,1 sh ./test_run/test_finetune.sh
+# Make sure --gpu_order is the same.
\ No newline at end of file
diff --git a/examples/research_projects/rag-end2end-retriever/use_own_knowledge_dataset.py b/examples/research_projects/rag-end2end-retriever/use_own_knowledge_dataset.py
index 213aa8d882fc..432111a2784c 100644
--- a/examples/research_projects/rag-end2end-retriever/use_own_knowledge_dataset.py
+++ b/examples/research_projects/rag-end2end-retriever/use_own_knowledge_dataset.py
@@ -121,7 +121,10 @@ class RagExampleArguments:
dpr_ctx_encoder_model_name: str = field(
default="facebook/dpr-ctx_encoder-multiset-base",
metadata={
- "help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
+ "help": (
+ "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or"
+ " 'facebook/dpr-ctx_encoder-multiset-base'"
+ )
},
)
output_dir: Optional[str] = field(
@@ -155,7 +158,9 @@ class IndexHnswArguments:
m: int = field(
default=128,
metadata={
- "help": "The number of bi-directional links created for every new element during the HNSW index construction."
+ "help": (
+ "The number of bi-directional links created for every new element during the HNSW index construction."
+ )
},
)
diff --git a/examples/research_projects/rag/callbacks_rag.py b/examples/research_projects/rag/callbacks_rag.py
index a2d87f82247c..af1595b08efd 100644
--- a/examples/research_projects/rag/callbacks_rag.py
+++ b/examples/research_projects/rag/callbacks_rag.py
@@ -29,7 +29,8 @@ def get_checkpoint_callback(output_dir, metric):
exp = "{val_avg_em:.4f}-{step_count}"
else:
raise NotImplementedError(
- f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
+ f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this"
+ " function."
)
checkpoint_callback = ModelCheckpoint(
diff --git a/examples/research_projects/rag/consolidate_rag_checkpoint.py b/examples/research_projects/rag/consolidate_rag_checkpoint.py
index b9ed7ec0f811..39ba7e91f6c3 100644
--- a/examples/research_projects/rag/consolidate_rag_checkpoint.py
+++ b/examples/research_projects/rag/consolidate_rag_checkpoint.py
@@ -80,7 +80,10 @@ def consolidate(
parser.add_argument(
"--config_name_or_path",
type=str,
- help="Identifier of the model config to use, if not provided, resolves to a base config for a given ``model_type``",
+ help=(
+ "Identifier of the model config to use, if not provided, resolves to a base config for a given"
+ " ``model_type``"
+ ),
)
args = parser.parse_args()
diff --git a/examples/research_projects/rag/eval_rag.py b/examples/research_projects/rag/eval_rag.py
index 05f78c3d6cdf..a8e7abbca6ce 100644
--- a/examples/research_projects/rag/eval_rag.py
+++ b/examples/research_projects/rag/eval_rag.py
@@ -146,7 +146,10 @@ def get_args():
"--model_type",
choices=["rag_sequence", "rag_token", "bart"],
type=str,
- help="RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the model_name_or_path",
+ help=(
+ "RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the"
+ " model_name_or_path"
+ ),
)
parser.add_argument(
"--index_name",
@@ -174,7 +177,10 @@ def get_args():
choices=["e2e", "retrieval"],
default="e2e",
type=str,
- help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates precision@k.",
+ help=(
+ "Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates"
+ " precision@k."
+ ),
)
parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
parser.add_argument(
@@ -196,9 +202,11 @@ def get_args():
default="qa",
type=str,
choices=["qa", "ans"],
- help="Format of the gold data file"
- "qa - a single line in the following format: question [tab] answer_list"
- "ans - a single line of the gold file contains the expected answer string",
+ help=(
+ "Format of the gold data file"
+ "qa - a single line in the following format: question [tab] answer_list"
+ "ans - a single line of the gold file contains the expected answer string"
+ ),
)
parser.add_argument(
"--predictions_path",
diff --git a/examples/research_projects/rag/finetune_rag.py b/examples/research_projects/rag/finetune_rag.py
index 2fd4ef7659c5..f5cef614e2d9 100644
--- a/examples/research_projects/rag/finetune_rag.py
+++ b/examples/research_projects/rag/finetune_rag.py
@@ -383,29 +383,37 @@ def add_model_specific_args(parser, root_dir):
"--max_source_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--max_target_length",
default=25,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--val_max_target_length",
default=25,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--test_max_target_length",
default=25,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
@@ -423,7 +431,10 @@ def add_model_specific_args(parser, root_dir):
type=int,
default=-1,
required=False,
- help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
+ help=(
+ "-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So"
+ " val_check_interval will effect it."
+ ),
)
parser.add_argument(
"--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
@@ -432,7 +443,10 @@ def add_model_specific_args(parser, root_dir):
"--model_type",
choices=["rag_sequence", "rag_token", "bart", "t5"],
type=str,
- help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
+ help=(
+ "RAG model type: sequence or token, if none specified, the type is inferred from the"
+ " model_name_or_path"
+ ),
)
return parser
@@ -442,39 +456,53 @@ def add_retriever_specific_args(parser):
"--index_name",
type=str,
default=None,
- help="Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom' for a local index, or 'legacy' for the orignal one)",
+ help=(
+ "Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom'"
+ " for a local index, or 'legacy' for the orignal one)"
+ ),
)
parser.add_argument(
"--passages_path",
type=str,
default=None,
- help="Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+ help=(
+ "Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever"
+ " documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
+ ),
)
parser.add_argument(
"--index_path",
type=str,
default=None,
- help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+ help=(
+ "Path to the faiss index for custom index. More info about custom indexes in the RagRetriever"
+ " documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
+ ),
)
parser.add_argument(
"--distributed_retriever",
choices=["ray", "pytorch"],
type=str,
default="pytorch",
- help="What implementation to use for distributed retriever? If "
- "pytorch is selected, the index is loaded on training "
- "worker 0, and torch.distributed is used to handle "
- "communication between training worker 0, and the other "
- "training workers. If ray is selected, the Ray library is "
- "used to create load the index on separate processes, "
- "and Ray handles the communication between the training "
- "workers and the retrieval actors.",
+ help=(
+ "What implementation to use for distributed retriever? If "
+ "pytorch is selected, the index is loaded on training "
+ "worker 0, and torch.distributed is used to handle "
+ "communication between training worker 0, and the other "
+ "training workers. If ray is selected, the Ray library is "
+ "used to create load the index on separate processes, "
+ "and Ray handles the communication between the training "
+ "workers and the retrieval actors."
+ ),
)
parser.add_argument(
"--use_dummy_dataset",
type=bool,
default=False,
- help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+ help=(
+ "Whether to use the dummy version of the dataset index. More info about custom indexes in the"
+ " RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
+ ),
)
return parser
@@ -485,18 +513,22 @@ def add_ray_specific_args(parser):
"--ray-address",
default="auto",
type=str,
- help="The address of the Ray cluster to connect to. If not "
- "specified, Ray will attempt to automatically detect the "
- "cluster. Has no effect if pytorch is used as the distributed "
- "retriever.",
+ help=(
+ "The address of the Ray cluster to connect to. If not "
+ "specified, Ray will attempt to automatically detect the "
+ "cluster. Has no effect if pytorch is used as the distributed "
+ "retriever."
+ ),
)
parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
- help="The number of retrieval actors to use when Ray is selected"
- "for the distributed retriever. Has no effect when "
- "distributed_retriever is set to pytorch.",
+ help=(
+ "The number of retrieval actors to use when Ray is selected"
+ "for the distributed retriever. Has no effect when "
+ "distributed_retriever is set to pytorch."
+ ),
)
return parser
@@ -514,7 +546,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
named_actors = []
if args.distributed_retriever == "ray" and args.gpus > 1:
if not is_ray_available():
- raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
+ raise RuntimeError("Please install Ray to use the Ray distributed retriever.")
# Connect to an existing Ray cluster.
try:
ray.init(address=args.ray_address, namespace="rag")
diff --git a/examples/research_projects/rag/lightning_base.py b/examples/research_projects/rag/lightning_base.py
index 1e0f67627e7c..77830a4760ad 100644
--- a/examples/research_projects/rag/lightning_base.py
+++ b/examples/research_projects/rag/lightning_base.py
@@ -321,8 +321,10 @@ def add_generic_args(parser, root_dir) -> None:
"--fp16_opt_level",
type=str,
default="O2",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
diff --git a/examples/research_projects/rag/use_own_knowledge_dataset.py b/examples/research_projects/rag/use_own_knowledge_dataset.py
index 269765caab86..dc08f508228a 100644
--- a/examples/research_projects/rag/use_own_knowledge_dataset.py
+++ b/examples/research_projects/rag/use_own_knowledge_dataset.py
@@ -154,7 +154,10 @@ class RagExampleArguments:
dpr_ctx_encoder_model_name: str = field(
default="facebook/dpr-ctx_encoder-multiset-base",
metadata={
- "help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
+ "help": (
+ "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or"
+ " 'facebook/dpr-ctx_encoder-multiset-base'"
+ )
},
)
output_dir: Optional[str] = field(
@@ -188,7 +191,9 @@ class IndexHnswArguments:
m: int = field(
default=128,
metadata={
- "help": "The number of bi-directional links created for every new element during the HNSW index construction."
+ "help": (
+ "The number of bi-directional links created for every new element during the HNSW index construction."
+ )
},
)
diff --git a/examples/research_projects/robust-speech-event/eval.py b/examples/research_projects/robust-speech-event/eval.py
index 53cd244daf75..32e3d1f2c729 100755
--- a/examples/research_projects/robust-speech-event/eval.py
+++ b/examples/research_projects/robust-speech-event/eval.py
@@ -24,7 +24,7 @@ def log_results(result: Dataset, args: Dict[str, str]):
cer_result = cer.compute(references=result["target"], predictions=result["prediction"])
# print & log results
- result_str = f"WER: {wer_result}\n" f"CER: {cer_result}"
+ result_str = f"WER: {wer_result}\nCER: {cer_result}"
print(result_str)
with open(f"{dataset_id}_eval_results.txt", "w") as f:
diff --git a/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_bnb.py b/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_bnb.py
index 2317367e7cc3..5294e6a4a9ae 100755
--- a/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_bnb.py
+++ b/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_bnb.py
@@ -103,9 +103,11 @@ class ModelArguments:
mask_time_prob: float = field(
default=0.05,
metadata={
- "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
- "vectors will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the time axis to be chosen as the start of the vector"
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
+ "vectors will be masked along the time axis."
+ )
},
)
mask_time_length: int = field(
@@ -115,8 +117,11 @@ class ModelArguments:
mask_feature_prob: float = field(
default=0.0,
metadata={
- "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the feature axis to be chosen as the start of the vectorspan"
+ " to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature"
+ " bins will be masked along the time axis."
+ )
},
)
mask_feature_length: int = field(
@@ -175,15 +180,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ )
},
)
chars_to_ignore: Optional[List[str]] = list_field(
@@ -197,7 +206,10 @@ class DataTrainingArguments:
max_duration_in_seconds: float = field(
default=20.0,
metadata={
- "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
+ "help": (
+ "Filter audio files that are longer than `max_duration_in_seconds` seconds to"
+ " 'max_duration_in_seconds`"
+ )
},
)
min_duration_in_seconds: float = field(
@@ -206,17 +218,21 @@ class DataTrainingArguments:
preprocessing_only: bool = field(
default=False,
metadata={
- "help": "Whether to only do data preprocessing and skip training. "
- "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
- "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
- "so that the cached datasets can consequently be loaded in distributed training"
+ "help": (
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
+ " can consequently be loaded in distributed training"
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "If :obj:`True`, will use the token generated when running"
- ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
+ "help": (
+ "If :obj:`True`, will use the token generated when running"
+ ":obj:`huggingface-cli login` as HTTP bearer authorization for remote files."
+ )
},
)
unk_token: str = field(
@@ -234,10 +250,12 @@ class DataTrainingArguments:
phoneme_language: Optional[str] = field(
default=None,
metadata={
- "help": "The target language that should be used be"
- " passed to the tokenizer for tokenization. Note that"
- " this is only relevant if the model classifies the"
- " input audio to a sequence of phoneme sequences."
+ "help": (
+ "The target language that should be used be"
+ " passed to the tokenizer for tokenization. Note that"
+ " this is only relevant if the model classifies the"
+ " input audio to a sequence of phoneme sequences."
+ )
},
)
@@ -286,13 +304,12 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
return_tensors="pt",
)
- with self.processor.as_target_processor():
- labels_batch = self.processor.pad(
- label_features,
- padding=self.padding,
- pad_to_multiple_of=self.pad_to_multiple_of_labels,
- return_tensors="pt",
- )
+ labels_batch = self.processor.pad(
+ labels=label_features,
+ padding=self.padding,
+ pad_to_multiple_of=self.pad_to_multiple_of_labels,
+ return_tensors="pt",
+ )
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
@@ -406,9 +423,9 @@ def main():
if data_args.audio_column_name not in raw_datasets["train"].column_names:
raise ValueError(
- f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
- "Make sure to set `--audio_column_name` to the correct audio column - one of "
- f"{', '.join(raw_datasets['train'].column_names)}."
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
+ " Make sure to set `--audio_column_name` to the correct audio column - one of"
+ f" {', '.join(raw_datasets['train'].column_names)}."
)
if data_args.text_column_name not in raw_datasets["train"].column_names:
@@ -743,7 +760,10 @@ def compute_metrics(pred):
"finetuned_from": model_args.model_name_or_path,
"tasks": "speech-recognition",
"tags": ["automatic-speech-recognition", data_args.dataset_name],
- "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
+ "dataset_args": (
+ f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
+ f" {data_args.eval_split_name}"
+ ),
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
}
if "common_voice" in data_args.dataset_name:
diff --git a/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_streaming.py b/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_streaming.py
index 9e69178088f6..8add8fd20a72 100644
--- a/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_streaming.py
+++ b/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_streaming.py
@@ -102,9 +102,11 @@ class ModelArguments:
mask_time_prob: float = field(
default=0.05,
metadata={
- "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
- "vectors will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the time axis to be chosen as the start of the vector"
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
+ "vectors will be masked along the time axis."
+ )
},
)
mask_time_length: int = field(
@@ -114,8 +116,11 @@ class ModelArguments:
mask_feature_prob: float = field(
default=0.0,
metadata={
- "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the feature axis to be chosen as the start of the vectorspan"
+ " to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature"
+ " bins will be masked along the time axis."
+ )
},
)
mask_feature_length: int = field(
@@ -147,8 +152,10 @@ class DataTrainingArguments:
train_split_name: str = field(
default="train+validation",
metadata={
- "help": "The name of the training data set split to use (via the datasets library). Defaults to "
- "'train+validation'"
+ "help": (
+ "The name of the training data set split to use (via the datasets library). Defaults to "
+ "'train+validation'"
+ )
},
)
eval_split_name: str = field(
@@ -175,22 +182,28 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ )
},
)
shuffle_buffer_size: Optional[int] = field(
default=500,
metadata={
- "help": "The number of streamed examples to download before shuffling them. The large the buffer, "
- "the closer it is to real offline shuffling."
+ "help": (
+ "The number of streamed examples to download before shuffling them. The large the buffer, "
+ "the closer it is to real offline shuffling."
+ )
},
)
chars_to_ignore: Optional[List[str]] = list_field(
@@ -208,26 +221,32 @@ class DataTrainingArguments:
preprocessing_only: bool = field(
default=False,
metadata={
- "help": "Whether to only do data preprocessing and skip training. "
- "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
- "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
- "so that the cached datasets can consequently be loaded in distributed training"
+ "help": (
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
+ " can consequently be loaded in distributed training"
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "If :obj:`True`, will use the token generated when running"
- ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
+ "help": (
+ "If :obj:`True`, will use the token generated when running"
+ ":obj:`huggingface-cli login` as HTTP bearer authorization for remote files."
+ )
},
)
phoneme_language: Optional[str] = field(
default=None,
metadata={
- "help": "The target language that should be used be"
- " passed to the tokenizer for tokenization. Note that"
- " this is only relevant if the model classifies the"
- " input audio to a sequence of phoneme sequences."
+ "help": (
+ "The target language that should be used be"
+ " passed to the tokenizer for tokenization. Note that"
+ " this is only relevant if the model classifies the"
+ " input audio to a sequence of phoneme sequences."
+ )
},
)
@@ -282,13 +301,12 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
return_tensors="pt",
)
- with self.processor.as_target_processor():
- labels_batch = self.processor.pad(
- label_features,
- padding=self.padding,
- pad_to_multiple_of=self.pad_to_multiple_of_labels,
- return_tensors="pt",
- )
+ labels_batch = self.processor.pad(
+ labels=label_features,
+ padding=self.padding,
+ pad_to_multiple_of=self.pad_to_multiple_of_labels,
+ return_tensors="pt",
+ )
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
@@ -393,9 +411,9 @@ def load_streaming_dataset(split, sampling_rate, **kwargs):
if data_args.audio_column_name not in raw_column_names["train"]:
raise ValueError(
- f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
- "Make sure to set `--audio_column_name` to the correct audio column - one of "
- f"{', '.join(raw_column_names['train'])}."
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
+ " Make sure to set `--audio_column_name` to the correct audio column - one of"
+ f" {', '.join(raw_column_names['train'])}."
)
if data_args.text_column_name not in raw_column_names["train"]:
@@ -641,7 +659,10 @@ def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
"finetuned_from": model_args.model_name_or_path,
"tasks": "speech-recognition",
"tags": ["automatic-speech-recognition", data_args.dataset_name],
- "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
+ "dataset_args": (
+ f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
+ f" {data_args.eval_split_name}"
+ ),
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
}
if "common_voice" in data_args.dataset_name:
diff --git a/examples/research_projects/self-training-text-classification/finetuning.py b/examples/research_projects/self-training-text-classification/finetuning.py
index 8ad92359b619..eeb0a285dff9 100644
--- a/examples/research_projects/self-training-text-classification/finetuning.py
+++ b/examples/research_projects/self-training-text-classification/finetuning.py
@@ -100,15 +100,19 @@ class FTDataArguments:
max_length: Optional[int] = dataclasses.field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: Optional[bool] = dataclasses.field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
@@ -147,7 +151,10 @@ class FTTrainingArguments:
weight_decay: Optional[float] = dataclasses.field(
default=0.0,
metadata={
- "help": "The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in [`AdamW`] optimizer."
+ "help": (
+ "The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in"
+ " [`AdamW`] optimizer."
+ )
},
)
learning_rate: Optional[float] = dataclasses.field(
@@ -157,13 +164,18 @@ class FTTrainingArguments:
gradient_accumulation_steps: Optional[int] = dataclasses.field(
default=1,
metadata={
- "help": "Number of updates steps to accumulate the gradients for, before performing a backward/update pass."
+ "help": (
+ "Number of updates steps to accumulate the gradients for, before performing a backward/update pass."
+ )
},
)
max_steps: Optional[int] = dataclasses.field(
default=-1,
metadata={
- "help": "If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`."
+ "help": (
+ "If set to a positive number, the total number of training steps to perform. Overrides"
+ " `num_train_epochs`."
+ )
},
)
lr_scheduler_type: Optional[str] = dataclasses.field(
@@ -172,7 +184,10 @@ class FTTrainingArguments:
warmup_steps: Optional[int] = dataclasses.field(
default=1,
metadata={
- "help": "Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`."
+ "help": (
+ "Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of"
+ " `warmup_ratio`."
+ )
},
)
evaluation_strategy: Optional[str] = dataclasses.field(
diff --git a/examples/research_projects/seq2seq-distillation/callbacks.py b/examples/research_projects/seq2seq-distillation/callbacks.py
index 388b6d53ddd3..6f6ed5dd58ac 100644
--- a/examples/research_projects/seq2seq-distillation/callbacks.py
+++ b/examples/research_projects/seq2seq-distillation/callbacks.py
@@ -93,7 +93,8 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=Fa
exp = "{val_avg_loss:.4f}-{step_count}"
else:
raise NotImplementedError(
- f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
+ f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to"
+ " this function."
)
checkpoint_callback = ModelCheckpoint(
diff --git a/examples/research_projects/seq2seq-distillation/distillation.py b/examples/research_projects/seq2seq-distillation/distillation.py
index 1f9106f0c0a7..5a403be8d562 100755
--- a/examples/research_projects/seq2seq-distillation/distillation.py
+++ b/examples/research_projects/seq2seq-distillation/distillation.py
@@ -52,9 +52,10 @@ def __init__(self, hparams):
student.config.length_penalty = hparams.length_penalty
hparams.tokenizer_name = hparams.teacher # Use teacher's tokenizer
super().__init__(hparams, model=student, config=student.config)
- assert (
- student.config.model_type == teacher.config.model_type
- ), f"teacher, student model types should be the same, got {student.config.model_type} != {teacher.config.model_type}"
+ assert student.config.model_type == teacher.config.model_type, (
+ f"teacher, student model types should be the same, got {student.config.model_type} !="
+ f" {teacher.config.model_type}"
+ )
if student.config.model_type == "t5":
student_encoder_layers = len(student.get_encoder().block)
diff --git a/examples/research_projects/seq2seq-distillation/finetune.py b/examples/research_projects/seq2seq-distillation/finetune.py
index 5874509377aa..c20b361d5836 100755
--- a/examples/research_projects/seq2seq-distillation/finetune.py
+++ b/examples/research_projects/seq2seq-distillation/finetune.py
@@ -303,29 +303,37 @@ def add_model_specific_args(parser, root_dir):
"--max_source_length",
default=1024,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--max_target_length",
default=56,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--val_max_target_length",
default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--test_max_target_length",
default=142,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument("--freeze_encoder", action="store_true")
parser.add_argument("--freeze_embeds", action="store_true")
@@ -353,7 +361,10 @@ def add_model_specific_args(parser, root_dir):
type=int,
default=-1,
required=False,
- help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
+ help=(
+ "-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So"
+ " val_check_interval will effect it."
+ ),
)
return parser
diff --git a/examples/research_projects/seq2seq-distillation/lightning_base.py b/examples/research_projects/seq2seq-distillation/lightning_base.py
index b7f53076e3bc..b3104a25a8b1 100644
--- a/examples/research_projects/seq2seq-distillation/lightning_base.py
+++ b/examples/research_projects/seq2seq-distillation/lightning_base.py
@@ -312,8 +312,10 @@ def add_generic_args(parser, root_dir) -> None:
"--fp16_opt_level",
type=str,
default="O2",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
diff --git a/examples/research_projects/seq2seq-distillation/make_student.py b/examples/research_projects/seq2seq-distillation/make_student.py
index 8d70292d0e5a..a4021505b998 100644
--- a/examples/research_projects/seq2seq-distillation/make_student.py
+++ b/examples/research_projects/seq2seq-distillation/make_student.py
@@ -58,7 +58,8 @@ def pick_layers_to_copy(n_student, n_teacher):
except KeyError:
if n_student != n_teacher:
warnings.warn(
- f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
+ f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first"
+ f" {n_student}"
)
return list(range(n_student))
@@ -144,7 +145,8 @@ def create_student_by_copying_alternating_layers(
if copy_first_teacher_layers: # Our copying is done. We just log and save
e_layers_to_copy, d_layers_to_copy = list(range(e)), list(range(d))
logger.info(
- f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
+ f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to"
+ f" {save_path}"
)
student.save_pretrained(save_path)
return student, e_layers_to_copy, d_layers_to_copy
diff --git a/examples/research_projects/seq2seq-distillation/run_eval.py b/examples/research_projects/seq2seq-distillation/run_eval.py
index de752c7df189..3f685884e8e8 100755
--- a/examples/research_projects/seq2seq-distillation/run_eval.py
+++ b/examples/research_projects/seq2seq-distillation/run_eval.py
@@ -108,7 +108,10 @@ def run_generate(verbose=True):
nargs="?",
type=str,
const=datetime_now(),
- help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.",
+ help=(
+ "use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g."
+ " lang=en-ru. If no value is passed, the current datetime string will be used."
+ ),
)
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
args, rest = parser.parse_known_args()
diff --git a/examples/research_projects/tapex/run_tabfact_with_tapex.py b/examples/research_projects/tapex/run_tabfact_with_tapex.py
index 0ed573ad9c1a..23d094f8992a 100644
--- a/examples/research_projects/tapex/run_tabfact_with_tapex.py
+++ b/examples/research_projects/tapex/run_tabfact_with_tapex.py
@@ -77,8 +77,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -87,29 +89,37 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
train_file: Optional[str] = field(
@@ -164,8 +174,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
diff --git a/examples/research_projects/tapex/run_wikisql_with_tapex.py b/examples/research_projects/tapex/run_wikisql_with_tapex.py
index 594c83cb6be5..1d402fa7e8f0 100644
--- a/examples/research_projects/tapex/run_wikisql_with_tapex.py
+++ b/examples/research_projects/tapex/run_wikisql_with_tapex.py
@@ -82,8 +82,10 @@ class ModelArguments:
tokenizer_name: Optional[str] = field(
default=None,
metadata={
- "help": "Pretrained tokenizer name or path if not the same as model_name. "
- "By default we use BART-large tokenizer for TAPEX-large."
+ "help": (
+ "Pretrained tokenizer name or path if not the same as model_name. "
+ "By default we use BART-large tokenizer for TAPEX-large."
+ )
},
)
cache_dir: Optional[str] = field(
@@ -101,8 +103,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -125,14 +129,15 @@ class DataTrainingArguments:
validation_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
- "(a jsonlines or csv file)."
+ "help": (
+ "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
+ )
},
)
test_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
+ "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
@@ -145,60 +150,76 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
- "which is used during ``evaluate`` and ``predict``."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``."
+ )
},
)
ignore_pad_token_for_loss: bool = field(
@@ -416,13 +437,12 @@ def _convert_table_types(_table):
table=tables, query=questions, max_length=data_args.max_source_length, padding=padding, truncation=True
)
- with tokenizer.as_target_tokenizer():
- labels = tokenizer(
- answer=[", ".join(answer) for answer in answers],
- max_length=max_target_length,
- padding=padding,
- truncation=True,
- )
+ labels = tokenizer(
+ answer=[", ".join(answer) for answer in answers],
+ max_length=max_target_length,
+ padding=padding,
+ truncation=True,
+ )
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
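The TAPEX hunks above drop the deprecated as_target_tokenizer() context manager and pass the targets straight to the tokenizer via its answer keyword. For generic seq2seq tokenizers the same migration goes through text_target; a minimal sketch of that pattern, assuming a transformers release recent enough to support it (the t5-small checkpoint is illustrative only):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # illustrative checkpoint

inputs = ["summarize: The quick brown fox jumps over the lazy dog."]
targets = ["A fox jumps over a dog."]

# Deprecated pattern:
# with tokenizer.as_target_tokenizer():
#     labels = tokenizer(targets, max_length=32, truncation=True)

# Replacement pattern: targets go into the same call and come back as "labels".
model_inputs = tokenizer(
    inputs, text_target=targets, max_length=64, padding="max_length", truncation=True
)
print(sorted(model_inputs.keys()))  # expected: ['attention_mask', 'input_ids', 'labels']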
diff --git a/examples/research_projects/tapex/run_wikitablequestions_with_tapex.py b/examples/research_projects/tapex/run_wikitablequestions_with_tapex.py
index 4398309566a8..6f93f9b51669 100644
--- a/examples/research_projects/tapex/run_wikitablequestions_with_tapex.py
+++ b/examples/research_projects/tapex/run_wikitablequestions_with_tapex.py
@@ -80,8 +80,10 @@ class ModelArguments:
tokenizer_name: Optional[str] = field(
default=None,
metadata={
- "help": "Pretrained tokenizer name or path if not the same as model_name. "
- "By default we use BART-large tokenizer for TAPEX-large."
+ "help": (
+ "Pretrained tokenizer name or path if not the same as model_name. "
+ "By default we use BART-large tokenizer for TAPEX-large."
+ )
},
)
cache_dir: Optional[str] = field(
@@ -99,8 +101,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -123,14 +127,15 @@ class DataTrainingArguments:
validation_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
- "(a jsonlines or csv file)."
+ "help": (
+ "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
+ )
},
)
test_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
+ "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
@@ -143,60 +148,76 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
- "which is used during ``evaluate`` and ``predict``."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``."
+ )
},
)
ignore_pad_token_for_loss: bool = field(
@@ -392,13 +413,12 @@ def preprocess_tableqa_function(examples, is_training=False):
table=tables, query=questions, max_length=data_args.max_source_length, padding=padding, truncation=True
)
- with tokenizer.as_target_tokenizer():
- labels = tokenizer(
- answer=[", ".join(answer) for answer in answers],
- max_length=max_target_length,
- padding=padding,
- truncation=True,
- )
+ labels = tokenizer(
+ answer=[", ".join(answer) for answer in answers],
+ max_length=max_target_length,
+ padding=padding,
+ truncation=True,
+ )
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
diff --git a/examples/research_projects/tapex/wikisql_utils.py b/examples/research_projects/tapex/wikisql_utils.py
index 9147fdc882e4..3028e81ad481 100644
--- a/examples/research_projects/tapex/wikisql_utils.py
+++ b/examples/research_projects/tapex/wikisql_utils.py
@@ -23,8 +23,6 @@
# Original: https://github.com/google-research/tapas/master/wikisql_utils.py
from typing import Any, List, Text
-import six
-
EMPTY_ANSWER = "none"
EMPTY_ANSWER_AGG = "none"
@@ -49,7 +47,7 @@ def convert_to_float(value):
return value
if isinstance(value, int):
return float(value)
- if not isinstance(value, six.string_types):
+ if not isinstance(value, str):
raise ValueError("Argument value is not a string. Can't parse it as float")
sanitized = value
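Dropping the six import works because these examples are Python-3-only, where there is a single string type. A condensed, illustrative sketch of the updated check; the real utility keeps sanitizing the string after this point, which is elided here:

def convert_to_float(value):
    if isinstance(value, float):
        return value
    if isinstance(value, int):
        return float(value)
    # Python 3 has one text type, so six.string_types collapses to plain str.
    if not isinstance(value, str):
        raise ValueError("Argument value is not a string. Can't parse it as float")
    sanitized = value
    # ...the original utility continues to normalize `sanitized` here...
    return float(sanitized)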
diff --git a/examples/research_projects/visual_bert/demo.ipynb b/examples/research_projects/visual_bert/demo.ipynb
index a025e419a3c6..14a65ce3df33 100644
--- a/examples/research_projects/visual_bert/demo.ipynb
+++ b/examples/research_projects/visual_bert/demo.ipynb
@@ -4,7 +4,7 @@
"cell_type": "code",
"execution_count": 1,
"source": [
- "#%pip install-r requirements.txt"
+ "# %pip install-r requirements.txt"
],
"outputs": [],
"metadata": {}
diff --git a/examples/research_projects/visual_bert/modeling_frcnn.py b/examples/research_projects/visual_bert/modeling_frcnn.py
index 39a0c6aea878..33c1133e9589 100644
--- a/examples/research_projects/visual_bert/modeling_frcnn.py
+++ b/examples/research_projects/visual_bert/modeling_frcnn.py
@@ -592,7 +592,7 @@ def __call__(self, match_quality_matrix):
match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
- for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
+ for l, low, high in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
low_high = (matched_vals >= low) & (matched_vals < high)
match_labels[low_high] = l
@@ -1037,9 +1037,9 @@ def make_stage(
curr_kwargs = {}
for k, v in kwargs.items():
if k.endswith("_per_block"):
- assert len(v) == num_blocks, (
- f"Argument '{k}' of make_stage should have the " f"same length as num_blocks={num_blocks}."
- )
+ assert (
+ len(v) == num_blocks
+ ), f"Argument '{k}' of make_stage should have the same length as num_blocks={num_blocks}."
newk = k[: -len("_per_block")]
assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
curr_kwargs[newk] = v[i]
@@ -1401,7 +1401,7 @@ def num_cell_anchors(self):
def grid_anchors(self, grid_sizes):
anchors = []
- for (size, stride, base_anchors) in zip(grid_sizes, self.strides, self.cell_anchors):
+ for size, stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors):
shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors.device)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
@@ -1708,10 +1708,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
- assert (
- from_tf
- ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
- pretrained_model_name_or_path + ".index"
+ assert from_tf, (
+ "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint"
+ .format(pretrained_model_name_or_path + ".index")
)
archive_file = pretrained_model_name_or_path + ".index"
else:
@@ -1797,26 +1796,28 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if len(unexpected_keys) > 0:
print(
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
- f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
- f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
- f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
- f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
- f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+ " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
+ " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
else:
print(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
print(
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
- f"and are newly initialized: {missing_keys}\n"
- f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
print(
- f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
- f"If your task is similar to the task the model of the checkpoint was trained on, "
- f"you can already use {model.__class__.__name__} for predictions without further training."
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+ " training."
)
if len(error_msgs) > 0:
raise RuntimeError(
diff --git a/examples/research_projects/visual_bert/requirements.txt b/examples/research_projects/visual_bert/requirements.txt
index fc3b85e16541..28a15ccb6ada 100644
--- a/examples/research_projects/visual_bert/requirements.txt
+++ b/examples/research_projects/visual_bert/requirements.txt
@@ -40,14 +40,14 @@ kiwisolver==1.2.0
lockfile==0.12.2
MarkupSafe==1.1.1
matplotlib==3.3.1
-mistune==0.8.4
+mistune==2.0.3
msgpack==0.6.2
nbclient==0.5.0
nbconvert==6.0.1
nbformat==5.0.7
nest-asyncio==1.4.0
-notebook==6.4.10
-numpy==1.21.0
+notebook==6.4.12
+numpy==1.22.0
opencv-python==4.4.0.42
packaging==20.3
pandas==1.1.2
diff --git a/examples/research_projects/visual_bert/utils.py b/examples/research_projects/visual_bert/utils.py
index 59ae11d025ad..8e830fb8359d 100644
--- a/examples/research_projects/visual_bert/utils.py
+++ b/examples/research_projects/visual_bert/utils.py
@@ -231,9 +231,10 @@ def compare(in_tensor):
n2 = out_tensor.numpy()[0]
print(n1.shape, n1[0, 0, :5])
print(n2.shape, n2[0, 0, :5])
- assert np.allclose(
- n1, n2, rtol=0.01, atol=0.1
- ), f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} % element-wise mismatch"
+ assert np.allclose(n1, n2, rtol=0.01, atol=0.1), (
+ f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} %"
+ " element-wise mismatch"
+ )
raise Exception("tensors are all good")
# Hugging face functions below
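The rewrapped assertion above reports the share of tensor elements that fail np.isclose as a percentage. A small standalone helper that computes the same quantity more readably (the function name is mine, not part of the script):

import numpy as np

def mismatch_percentage(n1, n2, rtol=0.01, atol=0.1):
    # Percentage of elements where n1 and n2 differ beyond the given tolerances.
    close = np.isclose(n1, n2, rtol=rtol, atol=atol)
    return 100.0 * (1.0 - close.mean())

a = np.zeros((2, 3))
b = a.copy()
b[0, 0] = 1.0
print(f"{mismatch_percentage(a, b):.4f} % element-wise mismatch")  # 16.6667 %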
diff --git a/examples/research_projects/wav2vec2/run_asr.py b/examples/research_projects/wav2vec2/run_asr.py
index 9b031cca1972..692aa39796a7 100755
--- a/examples/research_projects/wav2vec2/run_asr.py
+++ b/examples/research_projects/wav2vec2/run_asr.py
@@ -30,7 +30,7 @@
if is_apex_available():
from apex import amp
-if version.parse(torch.__version__) >= version.parse("1.6"):
+if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
_is_native_amp_available = True
from torch.cuda.amp import autocast
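Comparing against base_version keeps the native-AMP check from failing on pre-release or source-built PyTorch versions, whose pre-release tags sort below the final release. A short illustration with packaging; the nightly version string is made up for the example:

from packaging import version

nightly = "1.6.0a0+gitabc123"  # made-up source-build version string

# Pre-releases sort below the final release, so the plain comparison is False.
print(version.parse(nightly) >= version.parse("1.6"))  # False

# base_version strips the pre-release and local segments ("1.6.0"), so this passes.
print(version.parse(version.parse(nightly).base_version) >= version.parse("1.6"))  # True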
@@ -99,7 +99,9 @@ class DataTrainingArguments:
validation_split_name: Optional[str] = field(
default="validation",
metadata={
- "help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
+ "help": (
+ "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
+ )
},
)
target_text_column: Optional[str] = field(
@@ -121,7 +123,10 @@ class DataTrainingArguments:
orthography: Optional[str] = field(
default="librispeech",
metadata={
- "help": "Orthography used for normalization and tokenization: 'librispeech' (default), 'timit', or 'buckwalter'."
+ "help": (
+ "Orthography used for normalization and tokenization: 'librispeech' (default), 'timit', or"
+ " 'buckwalter'."
+ )
},
)
overwrite_cache: bool = field(
@@ -261,14 +266,13 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
- with self.processor.as_target_processor():
- labels_batch = self.processor.pad(
- label_features,
- padding=self.padding,
- max_length=self.max_length_labels,
- pad_to_multiple_of=self.pad_to_multiple_of_labels,
- return_tensors="pt",
- )
+ labels_batch = self.processor.pad(
+ labels=label_features,
+ padding=self.padding,
+ max_length=self.max_length_labels,
+ pad_to_multiple_of=self.pad_to_multiple_of_labels,
+ return_tensors="pt",
+ )
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
@@ -392,11 +396,13 @@ def filter_by_max_duration(example):
val_dataset = val_dataset.filter(filter_by_max_duration, remove_columns=["duration_in_seconds"])
if len(train_dataset) > old_train_size:
logger.warning(
- f"Filtered out {len(train_dataset) - old_train_size} train example(s) longer than {data_args.max_duration_in_seconds} second(s)."
+ f"Filtered out {len(train_dataset) - old_train_size} train example(s) longer than"
+ f" {data_args.max_duration_in_seconds} second(s)."
)
if len(val_dataset) > old_val_size:
logger.warning(
- f"Filtered out {len(val_dataset) - old_val_size} validation example(s) longer than {data_args.max_duration_in_seconds} second(s)."
+ f"Filtered out {len(val_dataset) - old_val_size} validation example(s) longer than"
+ f" {data_args.max_duration_in_seconds} second(s)."
)
logger.info(f"Split sizes: {len(train_dataset)} train and {len(val_dataset)} validation.")
@@ -412,9 +418,10 @@ def prepare_dataset(batch):
len(set(batch["sampling_rate"])) == 1
), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
- batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
- with processor.as_target_processor():
- batch["labels"] = processor(batch[data_args.target_text_column]).input_ids
+ processed_batch = processor(
+ audio=batch["speech"], text=batch[data_args.target_text_column], sampling_rate=batch["sampling_rate"][0]
+ )
+ batch.update(processed_batch)
return batch
train_dataset = train_dataset.map(
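The collator change above swaps the as_target_processor() context for an explicit labels= argument to processor.pad. A minimal sketch of the resulting collate function, assuming a transformers release recent enough to accept that keyword (the checkpoint name is illustrative only):

from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")  # illustrative

def collate(features, padding=True):
    input_features = [{"input_values": f["input_values"]} for f in features]
    label_features = [{"input_ids": f["labels"]} for f in features]

    # Inputs and label ids are padded in two separate calls; no context manager needed.
    batch = processor.pad(input_features, padding=padding, return_tensors="pt")
    labels_batch = processor.pad(labels=label_features, padding=padding, return_tensors="pt")

    # Mask padded label positions with -100 so the loss ignores them.
    batch["labels"] = labels_batch["input_ids"].masked_fill(
        labels_batch.attention_mask.ne(1), -100
    )
    return batch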
diff --git a/examples/research_projects/wav2vec2/run_common_voice.py b/examples/research_projects/wav2vec2/run_common_voice.py
index 5825c1feb10b..01a877a8092e 100644
--- a/examples/research_projects/wav2vec2/run_common_voice.py
+++ b/examples/research_projects/wav2vec2/run_common_voice.py
@@ -33,7 +33,7 @@
from apex import amp
-if version.parse(torch.__version__) >= version.parse("1.6"):
+if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
_is_native_amp_available = True
from torch.cuda.amp import autocast
@@ -79,9 +79,11 @@ class ModelArguments:
mask_time_prob: Optional[float] = field(
default=0.05,
metadata={
- "help": "Propability of each feature vector along the time axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
- "vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``."
+ "help": (
+ "Propability of each feature vector along the time axis to be chosen as the start of the vector"
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
+ "vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``."
+ )
},
)
layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."})
@@ -116,15 +118,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ )
},
)
chars_to_ignore: List[str] = list_field(
@@ -179,14 +185,13 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
- with self.processor.as_target_processor():
- labels_batch = self.processor.pad(
- label_features,
- padding=self.padding,
- max_length=self.max_length_labels,
- pad_to_multiple_of=self.pad_to_multiple_of_labels,
- return_tensors="pt",
- )
+ labels_batch = self.processor.pad(
+ labels=label_features,
+ padding=self.padding,
+ max_length=self.max_length_labels,
+ pad_to_multiple_of=self.pad_to_multiple_of_labels,
+ return_tensors="pt",
+ )
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
@@ -408,10 +413,11 @@ def prepare_dataset(batch):
assert (
len(set(batch["sampling_rate"])) == 1
), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
- batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
- # Setup the processor for targets
- with processor.as_target_processor():
- batch["labels"] = processor(batch["target_text"]).input_ids
+
+ processed_batch = processor(
+ audio=batch["speech"], text=batch["target_text"], sampling_rate=batch["sampling_rate"][0]
+ )
+ batch.update(processed_batch)
return batch
train_dataset = train_dataset.map(
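Preprocessing now makes one combined processor call that is expected to return both the audio features and the tokenized targets, which is what lets the hunk above simply batch.update() the result. A hedged, illustrative sketch of that call (checkpoint name and dummy audio are mine):

import numpy as np
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")  # illustrative

speech = [np.zeros(16000, dtype=np.float32)]  # one second of silence at 16 kHz
processed = processor(audio=speech, text=["HELLO WORLD"], sampling_rate=16000)
print(sorted(processed.keys()))  # expected to include 'input_values' and 'labels'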
diff --git a/examples/research_projects/wav2vec2/run_pretrain.py b/examples/research_projects/wav2vec2/run_pretrain.py
index 248f32443f04..8e0801429e61 100755
--- a/examples/research_projects/wav2vec2/run_pretrain.py
+++ b/examples/research_projects/wav2vec2/run_pretrain.py
@@ -26,7 +26,7 @@
if is_apex_available():
from apex import amp
-if version.parse(torch.__version__) >= version.parse("1.6"):
+if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
_is_native_amp_available = True
from torch.cuda.amp import autocast
@@ -104,7 +104,9 @@ class DataTrainingArguments:
validation_split_name: Optional[str] = field(
default="validation",
metadata={
- "help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
+ "help": (
+ "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
+ )
},
)
speech_file_column: Optional[str] = field(
@@ -200,7 +202,6 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
(batch_size, mask_indices_seq_length),
self.model.config.mask_time_prob,
self.model.config.mask_time_length,
- device=batch["input_values"].device,
attention_mask=attention_mask,
min_masks=2,
)
@@ -369,7 +370,8 @@ def normalize(batch):
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
raise ValueError(
- "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
+ "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and"
+ " ``config.feat_extract_norm='layer'"
)
model = Wav2Vec2ForPreTraining(config)
diff --git a/examples/research_projects/xtreme-s/run_xtreme_s.py b/examples/research_projects/xtreme-s/run_xtreme_s.py
index a186d4b7cee7..16fc1ac8a39c 100644
--- a/examples/research_projects/xtreme-s/run_xtreme_s.py
+++ b/examples/research_projects/xtreme-s/run_xtreme_s.py
@@ -89,7 +89,7 @@ class ModelArguments:
cache_dir: Optional[str] = field(
default=None,
metadata={
- "help": "Where do you want to store the pretrained models and datasets downloaded from " "huggingface.co"
+ "help": "Where do you want to store the pretrained models and datasets downloaded from huggingface.co"
},
)
freeze_feature_encoder: bool = field(
@@ -115,9 +115,11 @@ class ModelArguments:
mask_time_prob: float = field(
default=0.05,
metadata={
- "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
- "vectors will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the time axis to be chosen as the start of the vector"
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
+ "vectors will be masked along the time axis."
+ )
},
)
mask_time_length: int = field(
@@ -127,8 +129,11 @@ class ModelArguments:
mask_feature_prob: float = field(
default=0.0,
metadata={
- "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the feature axis to be chosen as the start of the vectorspan"
+ " to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature"
+ " bins will be masked along the time axis."
+ )
},
)
mask_feature_length: int = field(
@@ -162,8 +167,10 @@ class DataTrainingArguments:
task: str = field(
default=None,
metadata={
- "help": "The task name of the benchmark to use (via the datasets library). Should be on of: "
- "'fleurs-asr', 'mls', 'voxpopuli', 'covost2', 'minds14', 'fleurs-lang_id', 'babel'."
+ "help": (
+ "The task name of the benchmark to use (via the datasets library). Should be on of: "
+ "'fleurs-asr', 'mls', 'voxpopuli', 'covost2', 'minds14', 'fleurs-lang_id', 'babel'."
+ )
},
)
language: str = field(
@@ -173,10 +180,12 @@ class DataTrainingArguments:
language_group: str = field(
default=None,
metadata={
- "help": "The language group to select a subset of languages to train on. "
- "This option is only used the 'fleurs-asr' task. Should be one of: "
- "'western_european_we', 'eastern_european_ee', 'central_asia_middle_north_african_cmn', "
- "'sub_saharan_african_ssa', 'south_asian_sa', 'south_east_asian_sea', 'chinese_japanase_korean_cjk'."
+ "help": (
+ "The language group to select a subset of languages to train on. "
+ "This option is only used the 'fleurs-asr' task. Should be one of: "
+ "'western_european_we', 'eastern_european_ee', 'central_asia_middle_north_african_cmn', "
+ "'sub_saharan_african_ssa', 'south_asian_sa', 'south_east_asian_sea', 'chinese_japanase_korean_cjk'."
+ )
},
)
train_split_name: str = field(
@@ -188,14 +197,15 @@ class DataTrainingArguments:
eval_split_name: str = field(
default="validation",
metadata={
- "help": "The name of the evaluation dataset split to use (via the datasets library). "
- "Defaults to 'validation'"
+ "help": (
+ "The name of the evaluation dataset split to use (via the datasets library). Defaults to 'validation'"
+ )
},
)
predict_split_name: str = field(
default="test",
metadata={
- "help": "The name of the prediction dataset split to use (via the datasets library). " "Defaults to 'test'"
+ "help": "The name of the prediction dataset split to use (via the datasets library). Defaults to 'test'"
},
)
audio_column_name: str = field(
@@ -205,8 +215,10 @@ class DataTrainingArguments:
target_column_name: str = field(
default=None,
metadata={
- "help": "The name of the dataset column containing the target data "
- "(transcription/translation/label). If None, the name will be inferred from the task. Defaults to None."
+ "help": (
+ "The name of the dataset column containing the target data (transcription/translation/label). If None,"
+ " the name will be inferred from the task. Defaults to None."
+ )
},
)
overwrite_cache: bool = field(
@@ -219,22 +231,28 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
chars_to_ignore: Optional[List[str]] = list_field(
@@ -244,7 +262,10 @@ class DataTrainingArguments:
max_duration_in_seconds: float = field(
default=30.0,
metadata={
- "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
+ "help": (
+ "Filter audio files that are longer than `max_duration_in_seconds` seconds to"
+ " 'max_duration_in_seconds`"
+ )
},
)
min_duration_in_seconds: float = field(
@@ -253,17 +274,21 @@ class DataTrainingArguments:
preprocessing_only: bool = field(
default=False,
metadata={
- "help": "Whether to only do data preprocessing and skip training. "
- "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
- "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
- "so that the cached datasets can consequently be loaded in distributed training"
+ "help": (
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
+ " can consequently be loaded in distributed training"
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "If :obj:`True`, will use the token generated when running"
- ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
+ "help": (
+ "If :obj:`True`, will use the token generated when running"
+ ":obj:`huggingface-cli login` as HTTP bearer authorization for remote files."
+ )
},
)
unk_token: str = field(
@@ -281,17 +306,21 @@ class DataTrainingArguments:
phoneme_language: Optional[str] = field(
default=None,
metadata={
- "help": "The target language that should be used be"
- " passed to the tokenizer for tokenization. Note that"
- " this is only relevant if the model classifies the"
- " input audio to a sequence of phoneme sequences."
+ "help": (
+ "The target language that should be used be"
+ " passed to the tokenizer for tokenization. Note that"
+ " this is only relevant if the model classifies the"
+ " input audio to a sequence of phoneme sequences."
+ )
},
)
per_lang_metrics: bool = field(
default=True,
metadata={
- "help": "If `True`, compute the test metrics separately for each language, and average the results. "
- "If `False` compute the average test metrics in a single pass for all languages at once."
+ "help": (
+ "If `True`, compute the test metrics separately for each language, and average the results. "
+ "If `False` compute the average test metrics in a single pass for all languages at once."
+ )
},
)
@@ -320,13 +349,12 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
if self.pad_labels:
label_features = [{"input_ids": feature["labels"]} for feature in features]
- with self.processor.as_target_processor():
- labels_batch = self.processor.pad(
- label_features,
- padding=self.padding,
- pad_to_multiple_of=self.pad_to_multiple_of_labels,
- return_tensors="pt",
- )
+ labels_batch = self.processor.pad(
+ labels=label_features,
+ padding=self.padding,
+ pad_to_multiple_of=self.pad_to_multiple_of_labels,
+ return_tensors="pt",
+ )
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
@@ -446,7 +474,7 @@ def main():
if task_name is None:
raise ValueError(
- "Set --task should be set to '' " "(e.g. 'fleurs-asr', 'mls', 'covost2', 'minds14') "
+ "Set --task should be set to '' (e.g. 'fleurs-asr', 'mls', 'covost2', 'minds14') "
)
if lang_id is None:
raise ValueError(
@@ -481,9 +509,9 @@ def main():
if data_args.audio_column_name not in raw_datasets["train"].column_names:
raise ValueError(
- f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
- "Make sure to set `--audio_column_name` to the correct audio column - one of "
- f"{', '.join(raw_datasets['train'].column_names)}."
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
+ " Make sure to set `--audio_column_name` to the correct audio column - one of"
+ f" {', '.join(raw_datasets['train'].column_names)}."
)
if target_column_name not in raw_datasets["train"].column_names:
@@ -903,7 +931,10 @@ def compute_classification_metric(pred):
"finetuned_from": model_args.model_name_or_path,
"tasks": task_name,
"tags": [task_name, data_args.dataset_name],
- "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}",
+ "dataset_args": (
+ f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
+ f" {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}"
+ ),
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
"language": data_args.language,
}
diff --git a/examples/tensorflow/README.md b/examples/tensorflow/README.md
index 967a1a8b7869..7936e3d46509 100644
--- a/examples/tensorflow/README.md
+++ b/examples/tensorflow/README.md
@@ -15,7 +15,7 @@ limitations under the License.
# Examples
-This folder contains actively maintained examples of use of 🤗 Transformers organized into different NLP tasks. All examples in this folder are **TensorFlow** examples, and are written using native Keras rather than classes like `TFTrainer`, which we now consider deprecated. If you've previously only used 🤗 Transformers via `TFTrainer`, we highly recommend taking a look at the new style - we think it's a big improvement!
+This folder contains actively maintained examples of use of 🤗 Transformers organized into different ML tasks. All examples in this folder are **TensorFlow** examples, and are written using native Keras rather than classes like `TFTrainer`, which we now consider deprecated. If you've previously only used 🤗 Transformers via `TFTrainer`, we highly recommend taking a look at the new style - we think it's a big improvement!
In addition, all scripts here now support the [🤗 Datasets](https://github.com/huggingface/datasets) library - you can grab entire datasets just by changing one command-line argument!
diff --git a/examples/tensorflow/language-modeling/run_clm.py b/examples/tensorflow/language-modeling/run_clm.py
index 3598ad668a96..3f12683d10d9 100755
--- a/examples/tensorflow/language-modeling/run_clm.py
+++ b/examples/tensorflow/language-modeling/run_clm.py
@@ -53,6 +53,7 @@
create_optimizer,
set_seed,
)
+from transformers.utils import send_example_telemetry
from transformers.utils.versions import require_version
@@ -73,8 +74,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -84,8 +86,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
config_name: Optional[str] = field(
@@ -109,8 +113,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -150,9 +156,11 @@ class DataTrainingArguments:
block_size: Optional[int] = field(
default=None,
metadata={
- "help": "Optional input sequence length after tokenization. "
- "The training dataset will be truncated in block of this size for training. "
- "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ "help": (
+ "Optional input sequence length after tokenization. "
+ "The training dataset will be truncated in block of this size for training. "
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -166,15 +174,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
keep_linebreaks: bool = field(
@@ -221,6 +233,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_clm", model_args, data_args, framework="tensorflow")
+
# Sanity checks
if data_args.dataset_name is None and data_args.train_file is None and data_args.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
@@ -412,7 +428,8 @@ def group_texts(examples):
eval_dataset = lm_datasets["validation"]
else:
logger.info(
- f"Validation file not found: using {data_args.validation_split_percentage}% of the dataset as validation as provided in data_args"
+ f"Validation file not found: using {data_args.validation_split_percentage}% of the dataset as validation"
+ " as provided in data_args"
)
train_indices, val_indices = train_test_split(
list(range(len(train_dataset))), test_size=data_args.validation_split_percentage / 100
diff --git a/examples/tensorflow/language-modeling/run_mlm.py b/examples/tensorflow/language-modeling/run_mlm.py
index 8b32070b2dd1..b421ed8e669c 100755
--- a/examples/tensorflow/language-modeling/run_mlm.py
+++ b/examples/tensorflow/language-modeling/run_mlm.py
@@ -55,6 +55,7 @@
create_optimizer,
set_seed,
)
+from transformers.utils import send_example_telemetry
from transformers.utils.versions import require_version
@@ -74,8 +75,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -85,8 +87,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
config_name: Optional[str] = field(
@@ -110,8 +114,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -151,8 +157,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -169,22 +177,28 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -229,6 +243,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_mlm", model_args, data_args, framework="tensorflow")
+
# Sanity checks
if data_args.dataset_name is None and data_args.train_file is None and data_args.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
@@ -456,7 +474,8 @@ def group_texts(examples):
eval_dataset = tokenized_datasets["validation"]
else:
logger.info(
- f"Validation file not found: using {data_args.validation_split_percentage}% of the dataset as validation as provided in data_args"
+ f"Validation file not found: using {data_args.validation_split_percentage}% of the dataset as validation"
+ " as provided in data_args"
)
train_indices, val_indices = train_test_split(
list(range(len(train_dataset))), test_size=data_args.validation_split_percentage / 100
diff --git a/examples/tensorflow/multiple-choice/run_swag.py b/examples/tensorflow/multiple-choice/run_swag.py
index a1f39eeeb011..6ba35bd0fd20 100644
--- a/examples/tensorflow/multiple-choice/run_swag.py
+++ b/examples/tensorflow/multiple-choice/run_swag.py
@@ -44,11 +44,11 @@
set_seed,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from transformers.utils import PaddingStrategy, check_min_version
+from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
logger = logging.getLogger(__name__)
@@ -156,8 +156,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -183,30 +185,38 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. If passed, sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. If passed, sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to the maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to the maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -236,6 +246,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_swag", model_args, data_args, framework="tensorflow")
+
output_dir = Path(training_args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# endregion
diff --git a/examples/tensorflow/question-answering/requirements.txt b/examples/tensorflow/question-answering/requirements.txt
index 136ddf899b00..99aff2bb32b2 100644
--- a/examples/tensorflow/question-answering/requirements.txt
+++ b/examples/tensorflow/question-answering/requirements.txt
@@ -1,2 +1,3 @@
datasets >= 1.4.0
tensorflow >= 2.3.0
+evaluate >= 0.2.0
\ No newline at end of file
diff --git a/examples/tensorflow/question-answering/run_qa.py b/examples/tensorflow/question-answering/run_qa.py
index 877fe8800999..91293aefb35f 100755
--- a/examples/tensorflow/question-answering/run_qa.py
+++ b/examples/tensorflow/question-answering/run_qa.py
@@ -26,8 +26,9 @@
from typing import Optional
import tensorflow as tf
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
+import evaluate
import transformers
from transformers import (
AutoConfig,
@@ -41,12 +42,12 @@
TFTrainingArguments,
set_seed,
)
-from transformers.utils import CONFIG_NAME, TF2_WEIGHTS_NAME, check_min_version
+from transformers.utils import CONFIG_NAME, TF2_WEIGHTS_NAME, check_min_version, send_example_telemetry
from utils_qa import postprocess_qa_predictions
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
logger = logging.getLogger(__name__)
@@ -78,8 +79,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -115,37 +118,46 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=384,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
- "be faster on GPU but will be slower on TPU)."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
+ " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
version_2_with_negative: bool = field(
@@ -154,9 +166,11 @@ class DataTrainingArguments:
null_score_diff_threshold: float = field(
default=0.0,
metadata={
- "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`."
+ "help": (
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ )
},
)
doc_stride: int = field(
@@ -170,8 +184,10 @@ class DataTrainingArguments:
max_answer_length: int = field(
default=30,
metadata={
- "help": "The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another."
+ "help": (
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ )
},
)
@@ -227,6 +243,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_qa", model_args, data_args, framework="tensorflow")
+
output_dir = Path(training_args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# endregion
@@ -330,9 +350,9 @@ def main():
# region Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
- "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
- "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
- "requirement"
+ "This example script only works for models that have a fast tokenizer. Checkout the big table of models at"
+ " https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet"
+ " this requirement"
)
# endregion
@@ -581,7 +601,7 @@ def post_processing_function(examples, features, predictions, stage="eval"):
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
- metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
+ metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)
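As the requirements files in this diff show, metric loading moves from `datasets.load_metric` to the standalone `evaluate` package. A small sketch of the migrated SQuAD metric call, assuming `evaluate >= 0.2.0` is installed (the example id is made up):

```python
import evaluate

# Previously: from datasets import load_metric; metric = load_metric("squad")
metric = evaluate.load("squad")

predictions = [{"id": "example-0", "prediction_text": "Denver Broncos"}]
references = [{"id": "example-0", "answers": {"text": ["Denver Broncos"], "answer_start": [177]}}]

# Same compute() interface as before; only the provider of the metric changes.
print(metric.compute(predictions=predictions, references=references))
# {'exact_match': 100.0, 'f1': 100.0}
```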
diff --git a/examples/tensorflow/summarization/requirements.txt b/examples/tensorflow/summarization/requirements.txt
new file mode 100644
index 000000000000..99aff2bb32b2
--- /dev/null
+++ b/examples/tensorflow/summarization/requirements.txt
@@ -0,0 +1,3 @@
+datasets >= 1.4.0
+tensorflow >= 2.3.0
+evaluate >= 0.2.0
\ No newline at end of file
diff --git a/examples/tensorflow/summarization/run_summarization.py b/examples/tensorflow/summarization/run_summarization.py
index 6c4f1e5a9ed9..6d4cf99e6782 100644
--- a/examples/tensorflow/summarization/run_summarization.py
+++ b/examples/tensorflow/summarization/run_summarization.py
@@ -29,9 +29,10 @@
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
import tensorflow as tf
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
from tqdm import tqdm
+import evaluate
import transformers
from filelock import FileLock
from transformers import (
@@ -44,13 +45,13 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version, is_offline_mode
+from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry
from transformers.utils.versions import require_version
# region Checking dependencies
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
@@ -99,8 +100,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -131,14 +134,15 @@ class DataTrainingArguments:
validation_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
- "(a jsonlines or csv file)."
+ "help": (
+ "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
+ )
},
)
test_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
+ "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
@@ -151,60 +155,76 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
- "which is used during ``evaluate`` and ``predict``."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``."
+ )
},
)
ignore_pad_token_for_loss: bool = field(
@@ -247,6 +267,7 @@ def __post_init__(self):
"xglue": ("news_body", "news_title"),
"xsum": ("document", "summary"),
"wiki_summary": ("article", "highlights"),
+ "multi_news": ("document", "summary"),
}
# endregion
@@ -329,6 +350,10 @@ def main():
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_summarization", model_args, data_args, framework="tensorflow")
# endregion
# region Logging
@@ -479,9 +504,8 @@ def preprocess_function(examples):
inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
- # Setup the tokenizer for targets
- with tokenizer.as_target_tokenizer():
- labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
+ # Tokenize targets with the `text_target` keyword argument
+ labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
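The hunk above swaps the deprecated `as_target_tokenizer()` context manager for the `text_target` argument that tokenizers accept as of this release. A minimal before/after sketch, assuming a seq2seq checkpoint such as `t5-small` (any summarization checkpoint would do):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
targets = ["A cat sat down on the mat."]

# Old pattern: temporarily switch the tokenizer into target mode.
with tokenizer.as_target_tokenizer():
    old_labels = tokenizer(targets, max_length=128, truncation=True)

# New pattern used by these scripts: pass the targets via `text_target`.
new_labels = tokenizer(text_target=targets, max_length=128, truncation=True)

# Both produce the same label ids for this tokenizer.
assert old_labels["input_ids"] == new_labels["input_ids"]
```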
@@ -610,7 +634,7 @@ def masked_sparse_categorical_crossentropy(y_true, y_pred):
# endregion
# region Metric
- metric = load_metric("rouge")
+ metric = evaluate.load("rouge")
# endregion
# region Training
@@ -656,10 +680,7 @@ def masked_sparse_categorical_crossentropy(y_true, y_pred):
metric.add_batch(predictions=decoded_preds, references=decoded_labels)
result = metric.compute(use_stemmer=True)
- # Extract a few results from ROUGE
- result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
-
- result = {k: round(v, 4) for k, v in result.items()}
+ result = {k: round(v * 100, 4) for k, v in result.items()}
logger.info(result)
# endregion
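The metric hunk above also drops the `.mid.fmeasure` extraction: with `evaluate`'s ROUGE, `compute()` already returns plain float F-measures per ROUGE type rather than aggregate objects, so only scaling and rounding remain (this is the reading implied by the diff; behaviour may differ on other `evaluate`/`rouge_score` versions). A short sketch:

```python
import evaluate

metric = evaluate.load("rouge")
result = metric.compute(
    predictions=["the cat sat on the mat"],
    references=["the cat was sitting on the mat"],
    use_stemmer=True,
)

# `result` maps e.g. "rouge1", "rouge2", "rougeL", "rougeLsum" directly to floats,
# so the old `value.mid.fmeasure` unpacking is no longer needed.
result = {k: round(v * 100, 4) for k, v in result.items()}
print(result)
```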
diff --git a/examples/tensorflow/text-classification/requirements.txt b/examples/tensorflow/text-classification/requirements.txt
index 03d42cc5c89b..494a82127ab0 100644
--- a/examples/tensorflow/text-classification/requirements.txt
+++ b/examples/tensorflow/text-classification/requirements.txt
@@ -1,4 +1,5 @@
datasets >= 1.1.3
sentencepiece != 0.1.92
protobuf
-tensorflow >= 2.3
\ No newline at end of file
+tensorflow >= 2.3
+evaluate >= 0.2.0
\ No newline at end of file
diff --git a/examples/tensorflow/text-classification/run_glue.py b/examples/tensorflow/text-classification/run_glue.py
index c36476120eab..9fb0b3f8e434 100644
--- a/examples/tensorflow/text-classification/run_glue.py
+++ b/examples/tensorflow/text-classification/run_glue.py
@@ -24,8 +24,9 @@
import numpy as np
import tensorflow as tf
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
+import evaluate
import transformers
from transformers import (
AutoConfig,
@@ -39,7 +40,7 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
# region Helper functions
@@ -61,7 +62,7 @@ def on_epoch_end(self, epoch, logs=None):
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
task_to_keys = {
"cola": ("sentence", None),
@@ -99,8 +100,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -109,29 +112,37 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
@@ -171,8 +182,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -194,6 +207,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_glue", model_args, data_args, framework="tensorflow")
+
if not (training_args.do_train or training_args.do_eval or training_args.do_predict):
exit("Must specify at least one of --do_train, --do_eval or --do_predict!")
# endregion
@@ -350,7 +367,7 @@ def preprocess_function(examples):
# endregion
# region Metric function
- metric = load_metric("glue", data_args.task_name)
+ metric = evaluate.load("glue", data_args.task_name)
def compute_metrics(preds, label_ids):
preds = preds["logits"]
diff --git a/examples/tensorflow/text-classification/run_text_classification.py b/examples/tensorflow/text-classification/run_text_classification.py
index 3f3d64b6236d..b5d19032971c 100644
--- a/examples/tensorflow/text-classification/run_text_classification.py
+++ b/examples/tensorflow/text-classification/run_text_classification.py
@@ -37,7 +37,7 @@
TFTrainingArguments,
set_seed,
)
-from transformers.utils import CONFIG_NAME, TF2_WEIGHTS_NAME
+from transformers.utils import CONFIG_NAME, TF2_WEIGHTS_NAME, send_example_telemetry
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" # Reduce the amount of console output from TF
@@ -85,8 +85,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -95,30 +97,38 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
- "Data will always be padded when using TPUs."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "Data will always be padded when using TPUs."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ )
},
)
max_test_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of test examples to this "
+ "value if set."
+ )
},
)
@@ -162,8 +172,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -184,6 +196,11 @@ def main():
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_text_classification", model_args, data_args, framework="tensorflow")
+
output_dir = Path(training_args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# endregion
@@ -330,8 +347,8 @@ def main():
else:
logger.warning(
"Your model seems to have been trained with labels, but they don't match the dataset: ",
- f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
- "\nIgnoring the model labels as a result.",
+ f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels:"
+ f" {list(sorted(label_list))}.\nIgnoring the model labels as a result.",
)
label_to_id = {v: i for i, v in enumerate(label_list)}
elif not is_regression:
diff --git a/examples/tensorflow/token-classification/requirements.txt b/examples/tensorflow/token-classification/requirements.txt
new file mode 100644
index 000000000000..99aff2bb32b2
--- /dev/null
+++ b/examples/tensorflow/token-classification/requirements.txt
@@ -0,0 +1,3 @@
+datasets >= 1.4.0
+tensorflow >= 2.3.0
+evaluate >= 0.2.0
\ No newline at end of file
diff --git a/examples/tensorflow/token-classification/run_ner.py b/examples/tensorflow/token-classification/run_ner.py
index e580ed94b061..caa47e115a4b 100644
--- a/examples/tensorflow/token-classification/run_ner.py
+++ b/examples/tensorflow/token-classification/run_ner.py
@@ -27,8 +27,9 @@
import datasets
import numpy as np
import tensorflow as tf
-from datasets import ClassLabel, load_dataset, load_metric
+from datasets import ClassLabel, load_dataset
+import evaluate
import transformers
from transformers import (
CONFIG_MAPPING,
@@ -41,6 +42,7 @@
create_optimizer,
set_seed,
)
+from transformers.utils import send_example_telemetry
from transformers.utils.versions import require_version
@@ -80,8 +82,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -127,37 +131,47 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
label_all_tokens: bool = field(
default=False,
metadata={
- "help": "Whether to put the label for one word on all tokens of generated by that word or just on the "
- "one (in which case the other tokens will have a padding index)."
+ "help": (
+ "Whether to put the label for one word on all tokens of generated by that word or just on the "
+ "one (in which case the other tokens will have a padding index)."
+ )
},
)
return_entity_level_metrics: bool = field(
@@ -240,6 +254,10 @@ def main():
# region Argument Parsing
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_ner", model_args, data_args, framework="tensorflow")
# endregion
# region Setup logging
@@ -461,7 +479,7 @@ def dummy_loss(y_true, y_pred):
# endregion
# Metrics
- metric = load_metric("seqeval")
+ metric = evaluate.load("seqeval")
def get_labels(y_pred, y_true):
# Transform predictions and references tensos to numpy arrays
diff --git a/examples/tensorflow/translation/requirements.txt b/examples/tensorflow/translation/requirements.txt
new file mode 100644
index 000000000000..99aff2bb32b2
--- /dev/null
+++ b/examples/tensorflow/translation/requirements.txt
@@ -0,0 +1,3 @@
+datasets >= 1.4.0
+tensorflow >= 2.3.0
+evaluate >= 0.2.0
\ No newline at end of file
diff --git a/examples/tensorflow/translation/run_translation.py b/examples/tensorflow/translation/run_translation.py
index f81148a4af0b..7f5eb9eb9def 100644
--- a/examples/tensorflow/translation/run_translation.py
+++ b/examples/tensorflow/translation/run_translation.py
@@ -28,9 +28,10 @@
import datasets
import numpy as np
import tensorflow as tf
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
from tqdm import tqdm
+import evaluate
import transformers
from transformers import (
AutoConfig,
@@ -47,13 +48,13 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# region Dependencies and constants
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.22.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
@@ -93,8 +94,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -119,14 +122,15 @@ class DataTrainingArguments:
validation_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
- "(a jsonlines or csv file)."
+ "help": (
+ "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
+ )
},
)
test_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
+ "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
@@ -139,60 +143,76 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
- "which is used during ``evaluate`` and ``predict``."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``."
+ )
},
)
ignore_pad_token_for_loss: bool = field(
@@ -299,6 +319,10 @@ def main():
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_translation", model_args, data_args, framework="tensorflow")
# endregion
# region Logging
@@ -434,9 +458,8 @@ def preprocess_function(examples):
inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
- # Setup the tokenizer for targets
- with tokenizer.as_target_tokenizer():
- labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
+ # Tokenize targets with the `text_target` keyword argument
+ labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
@@ -567,7 +590,7 @@ def masked_sparse_categorical_crossentropy(y_true, y_pred):
# endregion
# region Metric and postprocessing
- metric = load_metric("sacrebleu")
+ metric = evaluate.load("sacrebleu")
def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
diff --git a/model_cards/README.md b/model_cards/README.md
index 4bf6ac6186f3..b2ee3e25a5d3 100644
--- a/model_cards/README.md
+++ b/model_cards/README.md
@@ -15,11 +15,7 @@ You can either:
**What if you want to create or update a model card for a model you don't have write access to?**
-In that case, given that we don't have a Pull request system yet on huggingface.co (🤯),
-you can open an issue here, post the card's content, and tag the model author(s) and/or the Hugging Face team.
-
-We might implement a more seamless process at some point, so your early feedback is precious!
-Please let us know of any suggestion.
+In that case, you can open a [Hub pull request](https://huggingface.co/docs/hub/repositories-pull-requests-discussions)! Check out the [announcement](https://huggingface.co/blog/community-update) of this feature for more details 🤗.
### What happened to the model cards here?
diff --git a/notebooks/README.md b/notebooks/README.md
index 073d2987027a..1a25cdd8044d 100644
--- a/notebooks/README.md
+++ b/notebooks/README.md
@@ -61,7 +61,9 @@ You can open any page of the documentation as a notebook in colab (there is a bu
| [How to export model to ONNX](https://github.com/huggingface/notebooks/blob/main/examples/onnx-export.ipynb)| Highlight how to export and run inference workloads through ONNX |
| [How to use Benchmarks](https://github.com/huggingface/notebooks/blob/main/examples/benchmark.ipynb)| How to benchmark models with transformers | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/benchmark.ipynb)| [](https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/main/examples/benchmark.ipynb)|
| [Reformer](https://github.com/huggingface/blog/blob/main/notebooks/03_reformer.ipynb)| How Reformer pushes the limits of language modeling | [](https://colab.research.google.com/github/patrickvonplaten/blog/blob/main/notebooks/03_reformer.ipynb)| [](https://studiolab.sagemaker.aws/import/github/patrickvonplaten/blog/blob/main/notebooks/03_reformer.ipynb)|
-| [How to fine-tune a model on image classification](https://github.com/huggingface/notebooks/blob/main/examples/image_classification.ipynb) | Show how to preprocess the data and fine-tune any pretrained Vision model on Image Classification | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb)| [](https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb)|
+| [How to fine-tune a model on image classification (Torchvision)](https://github.com/huggingface/notebooks/blob/main/examples/image_classification.ipynb) | Show how to preprocess the data using Torchvision and fine-tune any pretrained Vision model on Image Classification | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb)| [](https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb)|
+| [How to fine-tune a model on image classification (Albumentations)](https://github.com/huggingface/notebooks/blob/main/examples/image_classification_albumentations.ipynb) | Show how to preprocess the data using Albumentations and fine-tune any pretrained Vision model on Image Classification | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification_albumentations.ipynb)| [](https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/main/examples/image_classification_albumentations.ipynb)|
+| [How to perform zero-shot object detection with OWL-ViT](https://github.com/huggingface/notebooks/blob/main/examples/zeroshot_object_detection_with_owlvit.ipynb) | Show how to perform zero-shot object detection on images with text queries| [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/zeroshot_object_detection_with_owlvit.ipynb)| [](https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/main/examples/zeroshot_object_detection_with_owlvit.ipynb)|
### TensorFlow Examples
@@ -88,4 +90,4 @@ You can open any page of the documentation as a notebook in colab (there is a bu
## Community notebooks:
-More notebooks developed by the community are available [here](community#community-notebooks).
+More notebooks developed by the community are available [here](https://hf.co/docs/transformers/community#community-notebooks).
diff --git a/scripts/stale.py b/scripts/stale.py
index 056fbb469941..88d7efbd3b29 100644
--- a/scripts/stale.py
+++ b/scripts/stale.py
@@ -24,6 +24,7 @@
LABELS_TO_EXEMPT = [
"good first issue",
"good second issue",
+ "good difficult issue",
"feature request",
"new model",
"wip",
diff --git a/scripts/tatoeba/README.md b/scripts/tatoeba/README.md
index b86caf51d725..7c492ec4f46e 100644
--- a/scripts/tatoeba/README.md
+++ b/scripts/tatoeba/README.md
@@ -57,7 +57,7 @@ To upload all converted models,
2. Login to `transformers-cli`
```bash
-transformers-cli login
+huggingface-cli login
```
3. Run the `upload_models` script
diff --git a/scripts/tatoeba/upload_models.sh b/scripts/tatoeba/upload_models.sh
index 07c21edcbd51..536eb5bc68c4 100755
--- a/scripts/tatoeba/upload_models.sh
+++ b/scripts/tatoeba/upload_models.sh
@@ -2,7 +2,7 @@
for FILE in converted/*; do
model_name=`basename $FILE`
- transformers-cli repo create $model_name -y
+ huggingface-cli repo create $model_name -y
git clone https://huggingface.co/Helsinki-NLP/$model_name
mv $FILE/* $model_name/
cd $model_name
diff --git a/setup.py b/setup.py
index bb3598fda20d..05ec2c7617fd 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@
1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
documentation.
-
+
If releasing on a special branch, copy the updated README.md on the main branch for the commit you will make
for the post-release and run `make fix-copies` on the main branch as well.
@@ -27,12 +27,13 @@
3. Unpin specific versions from setup.py that use a git install.
-4. Commit these changes with the message: "Release: " and push.
+4. Checkout the release branch (v<RELEASE>-release, for example v4.19-release), and commit these changes with the
+ message: "Release: " and push.
5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs)
6. Add a tag in git to mark the release: "git tag v -m 'Adds tag v for pypi' "
- Push the tag to git: git push --tags origin main
+ Push the tag to git: git push --tags origin v<RELEASE>-release
7. Build both the sources and the wheel. Do not change anything in setup.py between
creating the wheel and the source distribution (obviously).
@@ -62,7 +63,7 @@
10. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
-11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release,
+11. Run `make post-release` then run `make fix-copies`. If you were on a branch for the release,
you need to go back to main before executing this.
"""
@@ -96,23 +97,26 @@
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
_deps = [
"Pillow",
- "black~=22.0",
+ "accelerate>=0.10.0",
+ "black==22.3",
"codecarbon==1.2.0",
"cookiecutter==1.7.3",
"dataclasses",
"datasets",
- "deepspeed>=0.6.0",
+ "deepspeed>=0.6.5",
+ "dill<0.3.5",
+ "evaluate>=0.2.0",
"fairscale>0.3",
"faiss-cpu",
"fastapi",
"filelock",
"flake8>=3.8.3",
- "flax>=0.3.5",
+ "flax>=0.4.1",
"ftfy",
"fugashi>=1.0",
"GitPython<3.1.19",
"hf-doc-builder>=0.3.0",
- "huggingface-hub>=0.1.0,<1.0",
+ "huggingface-hub>=0.8.1,<1.0",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",
@@ -129,14 +133,14 @@
"packaging>=20.0",
"parameterized",
"phonemizer",
- "protobuf",
+ "protobuf<=3.20.1",
"psutil",
"pyyaml>=5.1",
"pydantic",
"pytest",
"pytest-timeout",
"pytest-xdist",
- "python>=3.6.0",
+ "python>=3.7.0",
"ray[tune]",
"regex!=2019.12.17",
"requests",
@@ -152,11 +156,12 @@
"starlette",
"tensorflow-cpu>=2.3",
"tensorflow>=2.3",
+ "tensorflow-text",
"tf2onnx",
"timeout-decorator",
"timm",
"tokenizers>=0.11.1,!=0.11.3,<0.13",
- "torch>=1.0",
+ "torch>=1.0,!=0.12.0",
"torchaudio",
"pyctcdecode>=0.3.0",
"tqdm>=4.27",
@@ -235,10 +240,11 @@ def run(self):
extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic")
extras["sklearn"] = deps_list("scikit-learn")
-extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx")
-extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx")
+extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text")
+extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text")
extras["torch"] = deps_list("torch")
+extras["accelerate"] = deps_list("accelerate")
if os.name == "nt": # windows
extras["retrieval"] = deps_list("datasets") # faiss is not supported on windows
@@ -254,7 +260,7 @@ def run(self):
extras["modelcreation"] = deps_list("cookiecutter")
extras["sagemaker"] = deps_list("sagemaker")
-extras["deepspeed"] = deps_list("deepspeed")
+extras["deepspeed"] = deps_list("deepspeed") + extras["accelerate"]
extras["fairscale"] = deps_list("fairscale")
extras["optuna"] = deps_list("optuna")
extras["ray"] = deps_list("ray[tune]")
@@ -282,6 +288,8 @@ def run(self):
"parameterized",
"psutil",
"datasets",
+ "dill",
+ "evaluate",
"pytest-timeout",
"black",
"sacrebleu",
@@ -289,8 +297,9 @@ def run(self):
"nltk",
"GitPython",
"hf-doc-builder",
+ "protobuf", # Can be removed once we can unpin protobuf
"sacremoses",
- "rjieba"
+ "rjieba",
)
+ extras["retrieval"]
+ extras["modelcreation"]
@@ -311,6 +320,7 @@ def run(self):
+ extras["integrations"]
+ extras["timm"]
+ extras["codecarbon"]
+ + extras["accelerate"]
)
# Might need to add doc-builder and some specific deps in the future
@@ -320,8 +330,8 @@ def run(self):
extras["docs"] = extras["all"] + extras["docs_specific"]
extras["dev-torch"] = (
- extras['testing']
- + extras['torch']
+ extras["testing"]
+ + extras["torch"]
+ extras["sentencepiece"]
+ extras["tokenizers"]
+ extras["torch-speech"]
@@ -337,17 +347,17 @@ def run(self):
+ extras["onnxruntime"]
)
extras["dev-tensorflow"] = (
- extras['testing']
- + extras['tf']
- + extras["sentencepiece"]
- + extras["tokenizers"]
- + extras["vision"]
- + extras["quality"]
- + extras["docs_specific"]
- + extras["sklearn"]
- + extras["modelcreation"]
- + extras["onnx"]
- + extras["tf-speech"]
+ extras["testing"]
+ + extras["tf"]
+ + extras["sentencepiece"]
+ + extras["tokenizers"]
+ + extras["vision"]
+ + extras["quality"]
+ + extras["docs_specific"]
+ + extras["sklearn"]
+ + extras["modelcreation"]
+ + extras["onnx"]
+ + extras["tf-speech"]
)
extras["dev"] = (
extras["all"]
@@ -390,7 +400,7 @@ def run(self):
setup(
name="transformers",
- version="4.19.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="4.22.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
@@ -401,7 +411,6 @@ def run(self):
url="https://github.com/huggingface/transformers",
package_dir={"": "src"},
packages=find_packages("src"),
- package_data={"transformers": ["py.typed"]},
zip_safe=False,
extras_require=extras,
entry_points={"console_scripts": ["transformers-cli=transformers.commands.transformers_cli:main"]},
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 6d976ef6f2d7..0a97952b18b8 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -22,18 +22,20 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
-__version__ = "4.19.0.dev0"
+__version__ = "4.22.0.dev0"
from typing import TYPE_CHECKING
# Check the dependencies satisfy the minimal versions required.
from . import dependency_versions_check
from .utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_scatter_available,
is_sentencepiece_available,
is_speech_available,
+ is_tensorflow_text_available,
is_tf_available,
is_timm_available,
is_tokenizers_available,
@@ -155,6 +157,7 @@
"BlenderbotSmallConfig",
"BlenderbotSmallTokenizer",
],
+ "models.bloom": ["BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", "BloomConfig"],
"models.bort": [],
"models.byt5": ["ByT5Tokenizer"],
"models.camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"],
@@ -166,10 +169,12 @@
"CLIPTokenizer",
"CLIPVisionConfig",
],
+ "models.codegen": ["CODEGEN_PRETRAINED_CONFIG_ARCHIVE_MAP", "CodeGenConfig", "CodeGenTokenizer"],
"models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
"models.convnext": ["CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextConfig"],
"models.cpm": [],
"models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"],
+ "models.cvt": ["CVT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CvtConfig"],
"models.data2vec": [
"DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP",
@@ -197,13 +202,28 @@
"models.electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig", "ElectraTokenizer"],
"models.encoder_decoder": ["EncoderDecoderConfig"],
"models.flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig", "FlaubertTokenizer"],
+ "models.flava": [
+ "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "FlavaConfig",
+ "FlavaImageCodebookConfig",
+ "FlavaImageConfig",
+ "FlavaMultimodalConfig",
+ "FlavaTextConfig",
+ ],
"models.fnet": ["FNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "FNetConfig"],
"models.fsmt": ["FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FSMTConfig", "FSMTTokenizer"],
"models.funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig", "FunnelTokenizer"],
"models.glpn": ["GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP", "GLPNConfig"],
"models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"],
"models.gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
+ "models.gpt_neox": ["GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXConfig"],
"models.gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"],
+ "models.groupvit": [
+ "GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "GroupViTConfig",
+ "GroupViTTextConfig",
+ "GroupViTVisionConfig",
+ ],
"models.herbert": ["HerbertTokenizer"],
"models.hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"],
"models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
@@ -216,9 +236,18 @@
"LayoutLMv2Processor",
"LayoutLMv2Tokenizer",
],
+ "models.layoutlmv3": [
+ "LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "LayoutLMv3Config",
+ "LayoutLMv3FeatureExtractor",
+ "LayoutLMv3Processor",
+ "LayoutLMv3Tokenizer",
+ ],
"models.layoutxlm": ["LayoutXLMProcessor"],
"models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
+ "models.levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig"],
"models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],
+ "models.longt5": ["LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongT5Config"],
"models.luke": ["LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP", "LukeConfig", "LukeTokenizer"],
"models.lxmert": ["LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LxmertConfig", "LxmertTokenizer"],
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
@@ -226,18 +255,31 @@
"models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"],
"models.mbart": ["MBartConfig"],
"models.mbart50": [],
+ "models.mctct": ["MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MCTCTConfig", "MCTCTProcessor"],
"models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
"models.megatron_gpt2": [],
"models.mluke": [],
"models.mmbt": ["MMBTConfig"],
"models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"],
+ "models.mobilevit": ["MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileViTConfig"],
"models.mpnet": ["MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "MPNetConfig", "MPNetTokenizer"],
"models.mt5": ["MT5Config"],
+ "models.mvp": ["MvpConfig", "MvpTokenizer"],
+ "models.nezha": ["NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP", "NezhaConfig"],
+ "models.nllb": [],
"models.nystromformer": [
"NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"NystromformerConfig",
],
"models.openai": ["OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "OpenAIGPTConfig", "OpenAIGPTTokenizer"],
+ "models.opt": ["OPTConfig"],
+ "models.owlvit": [
+ "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "OwlViTConfig",
+ "OwlViTProcessor",
+ "OwlViTTextConfig",
+ "OwlViTVisionConfig",
+ ],
"models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"],
"models.perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverTokenizer"],
"models.phobert": ["PhobertTokenizer"],
@@ -271,9 +313,14 @@
"models.splinter": ["SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SplinterConfig", "SplinterTokenizer"],
"models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"],
"models.swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig"],
+ "models.swinv2": ["SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Swinv2Config"],
"models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
"models.tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig", "TapasTokenizer"],
"models.tapex": ["TapexTokenizer"],
+ "models.trajectory_transformer": [
+ "TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "TrajectoryTransformerConfig",
+ ],
"models.transfo_xl": [
"TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TransfoXLConfig",
@@ -294,6 +341,7 @@
"UniSpeechSatConfig",
],
"models.van": ["VAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "VanConfig"],
+ "models.videomae": ["VIDEOMAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "VideoMAEConfig"],
"models.vilt": ["VILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViltConfig", "ViltFeatureExtractor", "ViltProcessor"],
"models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"],
"models.vision_text_dual_encoder": ["VisionTextDualEncoderConfig", "VisionTextDualEncoderProcessor"],
@@ -308,6 +356,10 @@
"Wav2Vec2Processor",
"Wav2Vec2Tokenizer",
],
+ "models.wav2vec2_conformer": [
+ "WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "Wav2Vec2ConformerConfig",
+ ],
"models.wav2vec2_phoneme": ["Wav2Vec2PhonemeCTCTokenizer"],
"models.wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"],
"models.wavlm": [
@@ -347,6 +399,7 @@
"TextGenerationPipeline",
"TokenClassificationPipeline",
"TranslationPipeline",
+ "VisualQuestionAnsweringPipeline",
"ZeroShotClassificationPipeline",
"ZeroShotImageClassificationPipeline",
"pipeline",
@@ -371,7 +424,7 @@
"TrainerControl",
"TrainerState",
],
- "trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "set_seed"],
+ "trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "enable_full_determinism", "set_seed"],
"training_args": ["TrainingArguments"],
"training_args_seq2seq": ["Seq2SeqTrainingArguments"],
"training_args_tf": ["TFTrainingArguments"],
@@ -388,7 +441,6 @@
"TensorType",
"add_end_docstrings",
"add_start_docstrings",
- "cached_path",
"is_apex_available",
"is_datasets_available",
"is_faiss_available",
@@ -401,6 +453,7 @@
"is_sentencepiece_available",
"is_sklearn_available",
"is_speech_available",
+ "is_tensorflow_text_available",
"is_tf_available",
"is_timm_available",
"is_tokenizers_available",
@@ -412,7 +465,16 @@
}
# sentencepiece-backed objects
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_sentencepiece_objects
+
+ _import_structure["utils.dummy_sentencepiece_objects"] = [
+ name for name in dir(dummy_sentencepiece_objects) if not name.startswith("_")
+ ]
+else:
_import_structure["models.albert"].append("AlbertTokenizer")
_import_structure["models.barthez"].append("BarthezTokenizer")
_import_structure["models.bartpho"].append("BartphoTokenizer")
@@ -426,6 +488,7 @@
_import_structure["models.m2m_100"].append("M2M100Tokenizer")
_import_structure["models.marian"].append("MarianTokenizer")
_import_structure["models.mbart"].append("MBartTokenizer")
+ _import_structure["models.nllb"].append("NllbTokenizer")
_import_structure["models.mbart50"].append("MBart50Tokenizer")
_import_structure["models.mluke"].append("MLukeTokenizer")
_import_structure["models.mt5"].append("MT5Tokenizer")
@@ -439,16 +502,19 @@
_import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer")
_import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer")
_import_structure["models.xlnet"].append("XLNetTokenizer")
-else:
- from .utils import dummy_sentencepiece_objects
-
- _import_structure["utils.dummy_sentencepiece_objects"] = [
- name for name in dir(dummy_sentencepiece_objects) if not name.startswith("_")
- ]
# tokenizers-backed objects
-if is_tokenizers_available():
- # Fast tokenizers
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_tokenizers_objects
+
+ _import_structure["utils.dummy_tokenizers_objects"] = [
+ name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
+ ]
+else:
+ # Fast tokenizers structure
_import_structure["models.albert"].append("AlbertTokenizerFast")
_import_structure["models.bart"].append("BartTokenizerFast")
_import_structure["models.barthez"].append("BarthezTokenizerFast")
@@ -456,8 +522,10 @@
_import_structure["models.big_bird"].append("BigBirdTokenizerFast")
_import_structure["models.blenderbot"].append("BlenderbotTokenizerFast")
_import_structure["models.blenderbot_small"].append("BlenderbotSmallTokenizerFast")
+ _import_structure["models.bloom"].append("BloomTokenizerFast")
_import_structure["models.camembert"].append("CamembertTokenizerFast")
_import_structure["models.clip"].append("CLIPTokenizerFast")
+ _import_structure["models.codegen"].append("CodeGenTokenizerFast")
_import_structure["models.convbert"].append("ConvBertTokenizerFast")
_import_structure["models.cpm"].append("CpmTokenizerFast")
_import_structure["models.deberta"].append("DebertaTokenizerFast")
@@ -470,9 +538,11 @@
_import_structure["models.fnet"].append("FNetTokenizerFast")
_import_structure["models.funnel"].append("FunnelTokenizerFast")
_import_structure["models.gpt2"].append("GPT2TokenizerFast")
+ _import_structure["models.gpt_neox"].append("GPTNeoXTokenizerFast")
_import_structure["models.herbert"].append("HerbertTokenizerFast")
_import_structure["models.layoutlm"].append("LayoutLMTokenizerFast")
_import_structure["models.layoutlmv2"].append("LayoutLMv2TokenizerFast")
+ _import_structure["models.layoutlmv3"].append("LayoutLMv3TokenizerFast")
_import_structure["models.layoutxlm"].append("LayoutXLMTokenizerFast")
_import_structure["models.led"].append("LEDTokenizerFast")
_import_structure["models.longformer"].append("LongformerTokenizerFast")
@@ -482,6 +552,8 @@
_import_structure["models.mobilebert"].append("MobileBertTokenizerFast")
_import_structure["models.mpnet"].append("MPNetTokenizerFast")
_import_structure["models.mt5"].append("MT5TokenizerFast")
+ _import_structure["models.mvp"].append("MvpTokenizerFast")
+ _import_structure["models.nllb"].append("NllbTokenizerFast")
_import_structure["models.openai"].append("OpenAIGPTTokenizerFast")
_import_structure["models.pegasus"].append("PegasusTokenizerFast")
_import_structure["models.realm"].append("RealmTokenizerFast")
@@ -498,43 +570,69 @@
_import_structure["models.xlnet"].append("XLNetTokenizerFast")
_import_structure["tokenization_utils_fast"] = ["PreTrainedTokenizerFast"]
-else:
- from .utils import dummy_tokenizers_objects
- _import_structure["utils.dummy_tokenizers_objects"] = [
- name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
- ]
-
-if is_sentencepiece_available() and is_tokenizers_available():
- _import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"]
-else:
+try:
+ if not (is_sentencepiece_available() and is_tokenizers_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
from .utils import dummy_sentencepiece_and_tokenizers_objects
_import_structure["utils.dummy_sentencepiece_and_tokenizers_objects"] = [
name for name in dir(dummy_sentencepiece_and_tokenizers_objects) if not name.startswith("_")
]
+else:
+ _import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"]
# Speech-specific objects
-if is_speech_available():
- _import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")
-else:
+try:
+ if not is_speech_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
from .utils import dummy_speech_objects
_import_structure["utils.dummy_speech_objects"] = [
name for name in dir(dummy_speech_objects) if not name.startswith("_")
]
+else:
+ _import_structure["models.mctct"].append("MCTCTFeatureExtractor")
+ _import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")
-if is_sentencepiece_available() and is_speech_available():
- _import_structure["models.speech_to_text"].append("Speech2TextProcessor")
+# TensorFlow-text-specific objects
+try:
+ if not is_tensorflow_text_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_tensorflow_text_objects
+
+ _import_structure["utils.dummy_tensorflow_text_objects"] = [
+ name for name in dir(dummy_tensorflow_text_objects) if not name.startswith("_")
+ ]
else:
+ _import_structure["models.bert"].append("TFBertTokenizer")
+
+try:
+ if not (is_sentencepiece_available() and is_speech_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
from .utils import dummy_sentencepiece_and_speech_objects
_import_structure["utils.dummy_sentencepiece_and_speech_objects"] = [
name for name in dir(dummy_sentencepiece_and_speech_objects) if not name.startswith("_")
]
+else:
+ _import_structure["models.speech_to_text"].append("Speech2TextProcessor")
# Vision-specific objects
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_vision_objects
+
+ _import_structure["utils.dummy_vision_objects"] = [
+ name for name in dir(dummy_vision_objects) if not name.startswith("_")
+ ]
+else:
_import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
_import_structure["models.beit"].append("BeitFeatureExtractor")
_import_structure["models.clip"].append("CLIPFeatureExtractor")
@@ -543,28 +641,35 @@
_import_structure["models.deit"].append("DeiTFeatureExtractor")
_import_structure["models.detr"].append("DetrFeatureExtractor")
_import_structure["models.dpt"].append("DPTFeatureExtractor")
+ _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaProcessor"])
_import_structure["models.glpn"].append("GLPNFeatureExtractor")
_import_structure["models.imagegpt"].append("ImageGPTFeatureExtractor")
_import_structure["models.layoutlmv2"].append("LayoutLMv2FeatureExtractor")
- _import_structure["models.layoutlmv2"].append("LayoutLMv2Processor")
- _import_structure["models.layoutxlm"].append("LayoutXLMProcessor")
+ _import_structure["models.layoutlmv3"].append("LayoutLMv3FeatureExtractor")
+ _import_structure["models.levit"].append("LevitFeatureExtractor")
_import_structure["models.maskformer"].append("MaskFormerFeatureExtractor")
+ _import_structure["models.mobilevit"].append("MobileViTFeatureExtractor")
+ _import_structure["models.owlvit"].append("OwlViTFeatureExtractor")
_import_structure["models.perceiver"].append("PerceiverFeatureExtractor")
_import_structure["models.poolformer"].append("PoolFormerFeatureExtractor")
_import_structure["models.segformer"].append("SegformerFeatureExtractor")
+ _import_structure["models.videomae"].append("VideoMAEFeatureExtractor")
_import_structure["models.vilt"].append("ViltFeatureExtractor")
_import_structure["models.vilt"].append("ViltProcessor")
_import_structure["models.vit"].append("ViTFeatureExtractor")
_import_structure["models.yolos"].append("YolosFeatureExtractor")
-else:
- from .utils import dummy_vision_objects
-
- _import_structure["utils.dummy_vision_objects"] = [
- name for name in dir(dummy_vision_objects) if not name.startswith("_")
- ]
# Timm-backed objects
-if is_timm_available() and is_vision_available():
+try:
+ if not (is_timm_available() and is_vision_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_timm_objects
+
+ _import_structure["utils.dummy_timm_objects"] = [
+ name for name in dir(dummy_timm_objects) if not name.startswith("_")
+ ]
+else:
_import_structure["models.detr"].extend(
[
"DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -574,14 +679,17 @@
"DetrPreTrainedModel",
]
)
-else:
- from .utils import dummy_timm_objects
- _import_structure["utils.dummy_timm_objects"] = [
- name for name in dir(dummy_timm_objects) if not name.startswith("_")
- ]
+try:
+ if not is_scatter_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_scatter_objects
-if is_scatter_available():
+ _import_structure["utils.dummy_scatter_objects"] = [
+ name for name in dir(dummy_scatter_objects) if not name.startswith("_")
+ ]
+else:
_import_structure["models.tapas"].extend(
[
"TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -593,16 +701,17 @@
"load_tf_weights_in_tapas",
]
)
-else:
- from .utils import dummy_scatter_objects
-
- _import_structure["utils.dummy_scatter_objects"] = [
- name for name in dir(dummy_scatter_objects) if not name.startswith("_")
- ]
# PyTorch-backed objects
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_pt_objects
+
+ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
+else:
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
@@ -641,6 +750,7 @@
"TemperatureLogitsWarper",
"TopKLogitsWarper",
"TopPLogitsWarper",
+ "TypicalLogitsWarper",
]
_import_structure["generation_stopping_criteria"] = [
"MaxLengthCriteria",
@@ -690,7 +800,9 @@
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
+ "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel",
@@ -715,7 +827,9 @@
"AutoModelForSpeechSeq2Seq",
"AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification",
+ "AutoModelForVideoClassification",
"AutoModelForVision2Seq",
+ "AutoModelForVisualQuestionAnswering",
"AutoModelWithLMHead",
]
)
@@ -731,6 +845,17 @@
"PretrainedBartModel",
]
)
+ _import_structure["models.mvp"].extend(
+ [
+ "MVP_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "MvpForCausalLM",
+ "MvpForConditionalGeneration",
+ "MvpForQuestionAnswering",
+ "MvpForSequenceClassification",
+ "MvpModel",
+ "MvpPreTrainedModel",
+ ]
+ )
_import_structure["models.beit"].extend(
[
"BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -793,6 +918,16 @@
"BigBirdPegasusPreTrainedModel",
]
)
+ _import_structure["models.bloom"].extend(
+ [
+ "BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "BloomForCausalLM",
+ "BloomModel",
+ "BloomPreTrainedModel",
+ "BloomForSequenceClassification",
+ "BloomForTokenClassification",
+ ]
+ )
_import_structure["models.blenderbot"].extend(
[
"BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -876,6 +1011,14 @@
"CTRLPreTrainedModel",
]
)
+ _import_structure["models.cvt"].extend(
+ [
+ "CVT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "CvtForImageClassification",
+ "CvtModel",
+ "CvtPreTrainedModel",
+ ]
+ )
_import_structure["models.data2vec"].extend(
[
"DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -916,6 +1059,7 @@
[
"DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
"DebertaV2ForMaskedLM",
+ "DebertaV2ForMultipleChoice",
"DebertaV2ForQuestionAnswering",
"DebertaV2ForSequenceClassification",
"DebertaV2ForTokenClassification",
@@ -1005,6 +1149,18 @@
"FlaubertWithLMHeadModel",
]
)
+ _import_structure["models.flava"].extend(
+ [
+ "FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "FlavaForPreTraining",
+ "FlavaImageCodebook",
+ "FlavaImageModel",
+ "FlavaModel",
+ "FlavaMultimodalModel",
+ "FlavaPreTrainedModel",
+ "FlavaTextModel",
+ ]
+ )
_import_structure["models.fnet"].extend(
[
"FNET_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1066,6 +1222,15 @@
"load_tf_weights_in_gpt_neo",
]
)
+ _import_structure["models.gpt_neox"].extend(
+ [
+ "GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "GPTNeoXForCausalLM",
+ "GPTNeoXLayer",
+ "GPTNeoXModel",
+ "GPTNeoXPreTrainedModel",
+ ]
+ )
_import_structure["models.gptj"].extend(
[
"GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1076,6 +1241,23 @@
"GPTJPreTrainedModel",
]
)
+ _import_structure["models.groupvit"].extend(
+ [
+ "GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "GroupViTModel",
+ "GroupViTPreTrainedModel",
+ "GroupViTTextModel",
+ "GroupViTVisionModel",
+ ]
+ )
+ _import_structure["models.codegen"].extend(
+ [
+ "CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "CodeGenForCausalLM",
+ "CodeGenModel",
+ "CodeGenPreTrainedModel",
+ ]
+ )
_import_structure["models.hubert"].extend(
[
"HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1127,6 +1309,16 @@
"LayoutLMv2PreTrainedModel",
]
)
+ _import_structure["models.layoutlmv3"].extend(
+ [
+ "LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "LayoutLMv3ForQuestionAnswering",
+ "LayoutLMv3ForSequenceClassification",
+ "LayoutLMv3ForTokenClassification",
+ "LayoutLMv3Model",
+ "LayoutLMv3PreTrainedModel",
+ ]
+ )
_import_structure["models.led"].extend(
[
"LED_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1137,6 +1329,15 @@
"LEDPreTrainedModel",
]
)
+ _import_structure["models.levit"].extend(
+ [
+ "LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "LevitForImageClassification",
+ "LevitForImageClassificationWithTeacher",
+ "LevitModel",
+ "LevitPreTrainedModel",
+ ]
+ )
_import_structure["models.longformer"].extend(
[
"LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1150,12 +1351,25 @@
"LongformerSelfAttention",
]
)
+ _import_structure["models.longt5"].extend(
+ [
+ "LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "LongT5EncoderModel",
+ "LongT5ForConditionalGeneration",
+ "LongT5Model",
+ "LongT5PreTrainedModel",
+ ]
+ )
_import_structure["models.luke"].extend(
[
"LUKE_PRETRAINED_MODEL_ARCHIVE_LIST",
"LukeForEntityClassification",
"LukeForEntityPairClassification",
"LukeForEntitySpanClassification",
+ "LukeForMultipleChoice",
+ "LukeForQuestionAnswering",
+ "LukeForSequenceClassification",
+ "LukeForTokenClassification",
"LukeForMaskedLM",
"LukeModel",
"LukePreTrainedModel",
@@ -1199,6 +1413,14 @@
"MBartPreTrainedModel",
]
)
+ _import_structure["models.mctct"].extend(
+ [
+ "MCTCT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "MCTCTForCTC",
+ "MCTCTModel",
+ "MCTCTPreTrainedModel",
+ ]
+ )
_import_structure["models.megatron_bert"].extend(
[
"MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1231,6 +1453,15 @@
"load_tf_weights_in_mobilebert",
]
)
+ _import_structure["models.mobilevit"].extend(
+ [
+ "MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "MobileViTForImageClassification",
+ "MobileViTForSemanticSegmentation",
+ "MobileViTModel",
+ "MobileViTPreTrainedModel",
+ ]
+ )
_import_structure["models.mpnet"].extend(
[
"MPNET_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1245,6 +1476,20 @@
]
)
_import_structure["models.mt5"].extend(["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model"])
+ _import_structure["models.nezha"].extend(
+ [
+ "NEZHA_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "NezhaForMaskedLM",
+ "NezhaForPreTraining",
+ "NezhaForNextSentencePrediction",
+ "NezhaForMultipleChoice",
+ "NezhaForQuestionAnswering",
+ "NezhaForSequenceClassification",
+ "NezhaForTokenClassification",
+ "NezhaModel",
+ "NezhaPreTrainedModel",
+ ]
+ )
_import_structure["models.nystromformer"].extend(
[
"NYSTROMFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1269,6 +1514,25 @@
"load_tf_weights_in_openai_gpt",
]
)
+ _import_structure["models.opt"].extend(
+ [
+ "OPT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "OPTForCausalLM",
+ "OPTModel",
+ "OPTPreTrainedModel",
+ "OPTForSequenceClassification",
+ ]
+ )
+ _import_structure["models.owlvit"].extend(
+ [
+ "OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "OwlViTModel",
+ "OwlViTPreTrainedModel",
+ "OwlViTTextModel",
+ "OwlViTVisionModel",
+ "OwlViTForObjectDetection",
+ ]
+ )
_import_structure["models.pegasus"].extend(
["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"]
)
@@ -1465,6 +1729,7 @@
_import_structure["models.splinter"].extend(
[
"SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "SplinterForPreTraining",
"SplinterForQuestionAnswering",
"SplinterLayer",
"SplinterModel",
@@ -1493,6 +1758,15 @@
"SwinPreTrainedModel",
]
)
+ _import_structure["models.swinv2"].extend(
+ [
+ "SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "Swinv2ForImageClassification",
+ "Swinv2ForMaskedImageModeling",
+ "Swinv2Model",
+ "Swinv2PreTrainedModel",
+ ]
+ )
_import_structure["models.t5"].extend(
[
"T5_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1503,6 +1777,13 @@
"load_tf_weights_in_t5",
]
)
+ _import_structure["models.trajectory_transformer"].extend(
+ [
+ "TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TrajectoryTransformerModel",
+ "TrajectoryTransformerPreTrainedModel",
+ ]
+ )
_import_structure["models.transfo_xl"].extend(
[
"TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1552,6 +1833,7 @@
"VILT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ViltForImageAndTextRetrieval",
"ViltForImagesAndTextClassification",
+ "ViltForTokenClassification",
"ViltForMaskedLM",
"ViltForQuestionAnswering",
"ViltLayer",
@@ -1592,6 +1874,15 @@
"ViTMAEPreTrainedModel",
]
)
+ _import_structure["models.videomae"].extend(
+ [
+ "VIDEOMAE_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "VideoMAEForPreTraining",
+ "VideoMAEModel",
+ "VideoMAEPreTrainedModel",
+ "VideoMAEForVideoClassification",
+ ]
+ )
_import_structure["models.wav2vec2"].extend(
[
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1605,6 +1896,18 @@
"Wav2Vec2PreTrainedModel",
]
)
+ _import_structure["models.wav2vec2_conformer"].extend(
+ [
+ "WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "Wav2Vec2ConformerForAudioFrameClassification",
+ "Wav2Vec2ConformerForCTC",
+ "Wav2Vec2ConformerForPreTraining",
+ "Wav2Vec2ConformerForSequenceClassification",
+ "Wav2Vec2ConformerForXVector",
+ "Wav2Vec2ConformerModel",
+ "Wav2Vec2ConformerPreTrainedModel",
+ ]
+ )
_import_structure["models.wavlm"].extend(
[
"WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1723,13 +2026,16 @@
_import_structure["trainer"] = ["Trainer"]
_import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
_import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"]
-else:
- from .utils import dummy_pt_objects
-
- _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
# TensorFlow-backed objects
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_tf_objects
+
+ _import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")]
+else:
_import_structure["activations_tf"] = []
_import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
_import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
@@ -1775,11 +2081,13 @@
[
"TF_MODEL_FOR_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
+ "TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
"TF_MODEL_FOR_MASKED_LM_MAPPING",
"TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"TF_MODEL_FOR_PRETRAINING_MAPPING",
"TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
+ "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
"TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
@@ -1793,6 +2101,7 @@
"TFAutoModelForImageClassification",
"TFAutoModelForMaskedLM",
"TFAutoModelForMultipleChoice",
+ "TFAutoModelForNextSentencePrediction",
"TFAutoModelForPreTraining",
"TFAutoModelForQuestionAnswering",
"TFAutoModelForSeq2SeqLM",
@@ -1881,6 +2190,7 @@
_import_structure["models.data2vec"].extend(
[
"TFData2VecVisionForImageClassification",
+ "TFData2VecVisionForSemanticSegmentation",
"TFData2VecVisionModel",
"TFData2VecVisionPreTrainedModel",
]
@@ -1907,6 +2217,16 @@
"TFDebertaV2PreTrainedModel",
]
)
+ _import_structure["models.deit"].extend(
+ [
+ "TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TFDeiTForImageClassification",
+ "TFDeiTForImageClassificationWithTeacher",
+ "TFDeiTForMaskedImageModeling",
+ "TFDeiTModel",
+ "TFDeiTPreTrainedModel",
+ ]
+ )
_import_structure["models.distilbert"].extend(
[
"TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2080,6 +2400,13 @@
"TFOpenAIGPTPreTrainedModel",
]
)
+ _import_structure["models.opt"].extend(
+ [
+ "TFOPTForCausalLM",
+ "TFOPTModel",
+ "TFOPTPreTrainedModel",
+ ]
+ )
_import_structure["models.pegasus"].extend(
["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"]
)
@@ -2091,6 +2418,14 @@
"TFRagTokenForGeneration",
]
)
+ _import_structure["models.regnet"].extend(
+ [
+ "TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TFRegNetForImageClassification",
+ "TFRegNetModel",
+ "TFRegNetPreTrainedModel",
+ ]
+ )
_import_structure["models.rembert"].extend(
[
"TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2105,6 +2440,14 @@
"TFRemBertPreTrainedModel",
]
)
+ _import_structure["models.resnet"].extend(
+ [
+ "TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TFResNetForImageClassification",
+ "TFResNetModel",
+ "TFResNetPreTrainedModel",
+ ]
+ )
_import_structure["models.roberta"].extend(
[
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2133,6 +2476,16 @@
"TFRoFormerPreTrainedModel",
]
)
+ _import_structure["models.segformer"].extend(
+ [
+ "TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TFSegformerDecodeHead",
+ "TFSegformerForImageClassification",
+ "TFSegformerForSemanticSegmentation",
+ "TFSegformerModel",
+ "TFSegformerPreTrainedModel",
+ ]
+ )
_import_structure["models.speech_to_text"].extend(
[
"TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2141,6 +2494,15 @@
"TFSpeech2TextPreTrainedModel",
]
)
+ _import_structure["models.swin"].extend(
+ [
+ "TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TFSwinForImageClassification",
+ "TFSwinForMaskedImageModeling",
+ "TFSwinModel",
+ "TFSwinPreTrainedModel",
+ ]
+ )
_import_structure["models.t5"].extend(
[
"TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2235,13 +2597,18 @@
_import_structure["tf_utils"] = []
_import_structure["trainer_tf"] = ["TFTrainer"]
-else:
- from .utils import dummy_tf_objects
-
- _import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")]
# FLAX-backed objects
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_flax_objects
+
+ _import_structure["utils.dummy_flax_objects"] = [
+ name for name in dir(dummy_flax_objects) if not name.startswith("_")
+ ]
+else:
_import_structure["generation_flax_logits_process"] = [
"FlaxForcedBOSTokenLogitsProcessor",
"FlaxForcedEOSTokenLogitsProcessor",
@@ -2310,7 +2677,6 @@
"FlaxBartPreTrainedModel",
]
)
-
_import_structure["models.beit"].extend(
[
"FlaxBeitForImageClassification",
@@ -2319,6 +2685,7 @@
"FlaxBeitPreTrainedModel",
]
)
+
_import_structure["models.bert"].extend(
[
"FlaxBertForCausalLM",
@@ -2396,6 +2763,9 @@
["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
)
_import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
+ _import_structure["models.longt5"].extend(
+ ["FlaxLongT5ForConditionalGeneration", "FlaxLongT5Model", "FlaxLongT5PreTrainedModel"]
+ )
_import_structure["models.marian"].extend(
[
"FlaxMarianModel",
@@ -2412,7 +2782,14 @@
"FlaxMBartPreTrainedModel",
]
)
- _import_structure["models.mt5"].extend(["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
+ _import_structure["models.mt5"].extend(["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
+ _import_structure["models.opt"].extend(
+ [
+ "FlaxOPTForCausalLM",
+ "FlaxOPTModel",
+ "FlaxOPTPreTrainedModel",
+ ]
+ )
_import_structure["models.pegasus"].extend(
[
"FlaxPegasusForConditionalGeneration",
@@ -2444,7 +2821,9 @@
]
)
_import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel")
- _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
+ _import_structure["models.t5"].extend(
+ ["FlaxT5EncoderModel", "FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"]
+ )
_import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
_import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
@@ -2468,12 +2847,6 @@
"FlaxXLMRobertaModel",
]
)
-else:
- from .utils import dummy_flax_objects
-
- _import_structure["utils.dummy_flax_objects"] = [
- name for name in dir(dummy_flax_objects) if not name.startswith("_")
- ]
# Direct imports for type-checking
@@ -2577,6 +2950,7 @@
BlenderbotSmallConfig,
BlenderbotSmallTokenizer,
)
+ from .models.bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig
from .models.byt5 import ByT5Tokenizer
from .models.camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .models.canine import CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP, CanineConfig, CanineTokenizer
@@ -2587,9 +2961,11 @@
CLIPTokenizer,
CLIPVisionConfig,
)
+ from .models.codegen import CODEGEN_PRETRAINED_CONFIG_ARCHIVE_MAP, CodeGenConfig, CodeGenTokenizer
from .models.convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertTokenizer
from .models.convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig
from .models.ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig, CTRLTokenizer
+ from .models.cvt import CVT_PRETRAINED_CONFIG_ARCHIVE_MAP, CvtConfig
from .models.data2vec import (
DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -2618,13 +2994,28 @@
from .models.electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig, ElectraTokenizer
from .models.encoder_decoder import EncoderDecoderConfig
from .models.flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig, FlaubertTokenizer
+ from .models.flava import (
+ FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ FlavaConfig,
+ FlavaImageCodebookConfig,
+ FlavaImageConfig,
+ FlavaMultimodalConfig,
+ FlavaTextConfig,
+ )
from .models.fnet import FNET_PRETRAINED_CONFIG_ARCHIVE_MAP, FNetConfig
from .models.fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig, FSMTTokenizer
from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, FunnelTokenizer
from .models.glpn import GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP, GLPNConfig
from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer
from .models.gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
+ from .models.gpt_neox import GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoXConfig
from .models.gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig
+ from .models.groupvit import (
+ GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ GroupViTConfig,
+ GroupViTTextConfig,
+ GroupViTVisionConfig,
+ )
from .models.herbert import HerbertTokenizer
from .models.hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
@@ -2637,22 +3028,43 @@
LayoutLMv2Processor,
LayoutLMv2Tokenizer,
)
+ from .models.layoutlmv3 import (
+ LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ LayoutLMv3Config,
+ LayoutLMv3FeatureExtractor,
+ LayoutLMv3Processor,
+ LayoutLMv3Tokenizer,
+ )
from .models.layoutxlm import LayoutXLMProcessor
from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
+ from .models.levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig
from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer
+ from .models.longt5 import LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP, LongT5Config
from .models.luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig, LukeTokenizer
from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer
from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
from .models.marian import MarianConfig
from .models.maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig
from .models.mbart import MBartConfig
+ from .models.mctct import MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP, MCTCTConfig, MCTCTProcessor
from .models.megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig
from .models.mmbt import MMBTConfig
from .models.mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig, MobileBertTokenizer
+ from .models.mobilevit import MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileViTConfig
from .models.mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig, MPNetTokenizer
from .models.mt5 import MT5Config
+ from .models.mvp import MvpConfig, MvpTokenizer
+ from .models.nezha import NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP, NezhaConfig
from .models.nystromformer import NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, NystromformerConfig
from .models.openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, OpenAIGPTTokenizer
+ from .models.opt import OPTConfig
+ from .models.owlvit import (
+ OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ OwlViTConfig,
+ OwlViTProcessor,
+ OwlViTTextConfig,
+ OwlViTVisionConfig,
+ )
from .models.pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig, PegasusTokenizer
from .models.perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, PerceiverTokenizer
from .models.phobert import PhobertTokenizer
@@ -2683,9 +3095,14 @@
from .models.splinter import SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP, SplinterConfig, SplinterTokenizer
from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer
from .models.swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig
+ from .models.swinv2 import SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP, Swinv2Config
from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer
from .models.tapex import TapexTokenizer
+ from .models.trajectory_transformer import (
+ TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ TrajectoryTransformerConfig,
+ )
from .models.transfo_xl import (
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
TransfoXLConfig,
@@ -2696,6 +3113,7 @@
from .models.unispeech import UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechConfig
from .models.unispeech_sat import UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechSatConfig
from .models.van import VAN_PRETRAINED_CONFIG_ARCHIVE_MAP, VanConfig
+ from .models.videomae import VIDEOMAE_PRETRAINED_CONFIG_ARCHIVE_MAP, VideoMAEConfig
from .models.vilt import VILT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViltConfig, ViltFeatureExtractor, ViltProcessor
from .models.vision_encoder_decoder import VisionEncoderDecoderConfig
from .models.vision_text_dual_encoder import VisionTextDualEncoderConfig, VisionTextDualEncoderProcessor
@@ -2710,6 +3128,7 @@
Wav2Vec2Processor,
Wav2Vec2Tokenizer,
)
+ from .models.wav2vec2_conformer import WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2ConformerConfig
from .models.wav2vec2_phoneme import Wav2Vec2PhonemeCTCTokenizer
from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
from .models.wavlm import WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP, WavLMConfig
@@ -2747,6 +3166,7 @@
TextGenerationPipeline,
TokenClassificationPipeline,
TranslationPipeline,
+ VisualQuestionAnsweringPipeline,
ZeroShotClassificationPipeline,
ZeroShotImageClassificationPipeline,
pipeline,
@@ -2774,7 +3194,7 @@
TrainerControl,
TrainerState,
)
- from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, set_seed
+ from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, enable_full_determinism, set_seed
from .training_args import TrainingArguments
from .training_args_seq2seq import Seq2SeqTrainingArguments
from .training_args_tf import TFTrainingArguments
@@ -2793,7 +3213,6 @@
TensorType,
add_end_docstrings,
add_start_docstrings,
- cached_path,
is_apex_available,
is_datasets_available,
is_faiss_available,
@@ -2806,6 +3225,7 @@
is_sentencepiece_available,
is_sklearn_available,
is_speech_available,
+ is_tensorflow_text_available,
is_tf_available,
is_timm_available,
is_tokenizers_available,
@@ -2815,7 +3235,12 @@
logging,
)
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_sentencepiece_objects import *
+ else:
from .models.albert import AlbertTokenizer
from .models.barthez import BarthezTokenizer
from .models.bartpho import BartphoTokenizer
@@ -2831,6 +3256,7 @@
from .models.mbart import MBart50Tokenizer, MBartTokenizer
from .models.mluke import MLukeTokenizer
from .models.mt5 import MT5Tokenizer
+ from .models.nllb import NllbTokenizer
from .models.pegasus import PegasusTokenizer
from .models.plbart import PLBartTokenizer
from .models.reformer import ReformerTokenizer
@@ -2841,10 +3267,14 @@
from .models.xlm_prophetnet import XLMProphetNetTokenizer
from .models.xlm_roberta import XLMRobertaTokenizer
from .models.xlnet import XLNetTokenizer
- else:
- from .utils.dummy_sentencepiece_objects import *
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_tokenizers_objects import *
+ else:
+ # Fast tokenizers imports
from .models.albert import AlbertTokenizerFast
from .models.bart import BartTokenizerFast
from .models.barthez import BarthezTokenizerFast
@@ -2852,8 +3282,10 @@
from .models.big_bird import BigBirdTokenizerFast
from .models.blenderbot import BlenderbotTokenizerFast
from .models.blenderbot_small import BlenderbotSmallTokenizerFast
+ from .models.bloom import BloomTokenizerFast
from .models.camembert import CamembertTokenizerFast
from .models.clip import CLIPTokenizerFast
+ from .models.codegen import CodeGenTokenizerFast
from .models.convbert import ConvBertTokenizerFast
from .models.cpm import CpmTokenizerFast
from .models.deberta import DebertaTokenizerFast
@@ -2864,9 +3296,11 @@
from .models.fnet import FNetTokenizerFast
from .models.funnel import FunnelTokenizerFast
from .models.gpt2 import GPT2TokenizerFast
+ from .models.gpt_neox import GPTNeoXTokenizerFast
from .models.herbert import HerbertTokenizerFast
from .models.layoutlm import LayoutLMTokenizerFast
from .models.layoutlmv2 import LayoutLMv2TokenizerFast
+ from .models.layoutlmv3 import LayoutLMv3TokenizerFast
from .models.layoutxlm import LayoutXLMTokenizerFast
from .models.led import LEDTokenizerFast
from .models.longformer import LongformerTokenizerFast
@@ -2876,6 +3310,8 @@
from .models.mobilebert import MobileBertTokenizerFast
from .models.mpnet import MPNetTokenizerFast
from .models.mt5 import MT5TokenizerFast
+ from .models.mvp import MvpTokenizerFast
+ from .models.nllb import NllbTokenizerFast
from .models.openai import OpenAIGPTTokenizerFast
from .models.pegasus import PegasusTokenizerFast
from .models.realm import RealmTokenizerFast
@@ -2892,25 +3328,45 @@
from .models.xlnet import XLNetTokenizerFast
from .tokenization_utils_fast import PreTrainedTokenizerFast
+ try:
+ if not (is_sentencepiece_available() and is_tokenizers_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_sentencepiece_and_tokenizers_objects import *
else:
- from .utils.dummy_tokenizers_objects import *
-
- if is_sentencepiece_available() and is_tokenizers_available():
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer
- else:
- from .utils.dummies_sentencepiece_and_tokenizers_objects import *
- if is_speech_available():
- from .models.speech_to_text import Speech2TextFeatureExtractor
- else:
+ try:
+ if not is_speech_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
from .utils.dummy_speech_objects import *
+ else:
+ from .models.mctct import MCTCTFeatureExtractor
+ from .models.speech_to_text import Speech2TextFeatureExtractor
- if is_speech_available() and is_sentencepiece_available():
- from .models.speech_to_text import Speech2TextProcessor
+ try:
+ if not is_tensorflow_text_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_tensorflow_text_objects import *
else:
+ from .models.bert import TFBertTokenizer
+
+ try:
+ if not (is_speech_available() and is_sentencepiece_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
from .utils.dummy_sentencepiece_and_speech_objects import *
+ else:
+ from .models.speech_to_text import Speech2TextProcessor
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_vision_objects import *
+ else:
from .image_utils import ImageFeatureExtractionMixin
from .models.beit import BeitFeatureExtractor
from .models.clip import CLIPFeatureExtractor, CLIPProcessor
@@ -2918,22 +3374,30 @@
from .models.deit import DeiTFeatureExtractor
from .models.detr import DetrFeatureExtractor
from .models.dpt import DPTFeatureExtractor
+ from .models.flava import FlavaFeatureExtractor, FlavaProcessor
from .models.glpn import GLPNFeatureExtractor
from .models.imagegpt import ImageGPTFeatureExtractor
- from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2Processor
- from .models.layoutxlm import LayoutXLMProcessor
+ from .models.layoutlmv2 import LayoutLMv2FeatureExtractor
+ from .models.layoutlmv3 import LayoutLMv3FeatureExtractor
+ from .models.levit import LevitFeatureExtractor
from .models.maskformer import MaskFormerFeatureExtractor
+ from .models.mobilevit import MobileViTFeatureExtractor
+ from .models.owlvit import OwlViTFeatureExtractor
from .models.perceiver import PerceiverFeatureExtractor
from .models.poolformer import PoolFormerFeatureExtractor
from .models.segformer import SegformerFeatureExtractor
+ from .models.videomae import VideoMAEFeatureExtractor
from .models.vilt import ViltFeatureExtractor, ViltProcessor
from .models.vit import ViTFeatureExtractor
from .models.yolos import YolosFeatureExtractor
- else:
- from .utils.dummy_vision_objects import *
# Modeling
- if is_timm_available() and is_vision_available():
+ try:
+ if not (is_timm_available() and is_vision_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_timm_objects import *
+ else:
from .models.detr import (
DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
DetrForObjectDetection,
@@ -2941,10 +3405,13 @@
DetrModel,
DetrPreTrainedModel,
)
- else:
- from .utils.dummy_timm_objects import *
- if is_scatter_available():
+ try:
+ if not is_scatter_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_scatter_objects import *
+ else:
from .models.tapas import (
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TapasForMaskedLM,
@@ -2954,10 +3421,13 @@
TapasPreTrainedModel,
load_tf_weights_in_tapas,
)
- else:
- from .utils.dummy_scatter_objects import *
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_pt_objects import *
+ else:
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
@@ -2995,6 +3465,7 @@
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
+ TypicalLogitsWarper,
)
from .generation_stopping_criteria import (
MaxLengthCriteria,
@@ -3004,6 +3475,8 @@
)
from .generation_utils import top_k_top_p_filtering
from .modeling_utils import PreTrainedModel
+
+ # PyTorch model imports
from .models.albert import (
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
AlbertForMaskedLM,
@@ -3038,7 +3511,9 @@
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+ MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
+ MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
AutoModel,
@@ -3063,7 +3538,9 @@
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
+ AutoModelForVideoClassification,
AutoModelForVision2Seq,
+ AutoModelForVisualQuestionAnswering,
AutoModelWithLMHead,
)
from .models.bart import (
@@ -3142,6 +3619,14 @@
BlenderbotSmallModel,
BlenderbotSmallPreTrainedModel,
)
+ from .models.bloom import (
+ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
+ BloomForCausalLM,
+ BloomForSequenceClassification,
+ BloomForTokenClassification,
+ BloomModel,
+ BloomPreTrainedModel,
+ )
from .models.camembert import (
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
CamembertForCausalLM,
@@ -3170,6 +3655,12 @@
CLIPTextModel,
CLIPVisionModel,
)
+ from .models.codegen import (
+ CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST,
+ CodeGenForCausalLM,
+ CodeGenModel,
+ CodeGenPreTrainedModel,
+ )
from .models.convbert import (
CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
ConvBertForMaskedLM,
@@ -3195,6 +3686,12 @@
CTRLModel,
CTRLPreTrainedModel,
)
+ from .models.cvt import (
+ CVT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ CvtForImageClassification,
+ CvtModel,
+ CvtPreTrainedModel,
+ )
from .models.data2vec import (
DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST,
DATA2VEC_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -3230,6 +3727,7 @@
from .models.deberta_v2 import (
DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST,
DebertaV2ForMaskedLM,
+ DebertaV2ForMultipleChoice,
DebertaV2ForQuestionAnswering,
DebertaV2ForSequenceClassification,
DebertaV2ForTokenClassification,
@@ -3304,6 +3802,16 @@
FlaubertModel,
FlaubertWithLMHeadModel,
)
+ from .models.flava import (
+ FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
+ FlavaForPreTraining,
+ FlavaImageCodebook,
+ FlavaImageModel,
+ FlavaModel,
+ FlavaMultimodalModel,
+ FlavaPreTrainedModel,
+ FlavaTextModel,
+ )
from .models.fnet import (
FNET_PRETRAINED_MODEL_ARCHIVE_LIST,
FNetForMaskedLM,
@@ -3355,6 +3863,13 @@
GPTNeoPreTrainedModel,
load_tf_weights_in_gpt_neo,
)
+ from .models.gpt_neox import (
+ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
+ GPTNeoXForCausalLM,
+ GPTNeoXLayer,
+ GPTNeoXModel,
+ GPTNeoXPreTrainedModel,
+ )
from .models.gptj import (
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTJForCausalLM,
@@ -3363,6 +3878,13 @@
GPTJModel,
GPTJPreTrainedModel,
)
+ from .models.groupvit import (
+ GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ GroupViTModel,
+ GroupViTPreTrainedModel,
+ GroupViTTextModel,
+ GroupViTVisionModel,
+ )
from .models.hubert import (
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
HubertForCTC,
@@ -3404,6 +3926,14 @@
LayoutLMv2Model,
LayoutLMv2PreTrainedModel,
)
+ from .models.layoutlmv3 import (
+ LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST,
+ LayoutLMv3ForQuestionAnswering,
+ LayoutLMv3ForSequenceClassification,
+ LayoutLMv3ForTokenClassification,
+ LayoutLMv3Model,
+ LayoutLMv3PreTrainedModel,
+ )
from .models.led import (
LED_PRETRAINED_MODEL_ARCHIVE_LIST,
LEDForConditionalGeneration,
@@ -3412,6 +3942,13 @@
LEDModel,
LEDPreTrainedModel,
)
+ from .models.levit import (
+ LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ LevitForImageClassification,
+ LevitForImageClassificationWithTeacher,
+ LevitModel,
+ LevitPreTrainedModel,
+ )
from .models.longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
LongformerForMaskedLM,
@@ -3423,12 +3960,23 @@
LongformerPreTrainedModel,
LongformerSelfAttention,
)
+ from .models.longt5 import (
+ LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST,
+ LongT5EncoderModel,
+ LongT5ForConditionalGeneration,
+ LongT5Model,
+ LongT5PreTrainedModel,
+ )
from .models.luke import (
LUKE_PRETRAINED_MODEL_ARCHIVE_LIST,
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
+ LukeForMultipleChoice,
+ LukeForQuestionAnswering,
+ LukeForSequenceClassification,
+ LukeForTokenClassification,
LukeModel,
LukePreTrainedModel,
)
@@ -3462,6 +4010,7 @@
MBartModel,
MBartPreTrainedModel,
)
+ from .models.mctct import MCTCT_PRETRAINED_MODEL_ARCHIVE_LIST, MCTCTForCTC, MCTCTModel, MCTCTPreTrainedModel
from .models.megatron_bert import (
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MegatronBertForCausalLM,
@@ -3490,6 +4039,13 @@
MobileBertPreTrainedModel,
load_tf_weights_in_mobilebert,
)
+ from .models.mobilevit import (
+ MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ MobileViTForImageClassification,
+ MobileViTForSemanticSegmentation,
+ MobileViTModel,
+ MobileViTPreTrainedModel,
+ )
from .models.mpnet import (
MPNET_PRETRAINED_MODEL_ARCHIVE_LIST,
MPNetForMaskedLM,
@@ -3502,6 +4058,27 @@
MPNetPreTrainedModel,
)
from .models.mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
+ from .models.mvp import (
+ MVP_PRETRAINED_MODEL_ARCHIVE_LIST,
+ MvpForCausalLM,
+ MvpForConditionalGeneration,
+ MvpForQuestionAnswering,
+ MvpForSequenceClassification,
+ MvpModel,
+ MvpPreTrainedModel,
+ )
+ from .models.nezha import (
+ NEZHA_PRETRAINED_MODEL_ARCHIVE_LIST,
+ NezhaForMaskedLM,
+ NezhaForMultipleChoice,
+ NezhaForNextSentencePrediction,
+ NezhaForPreTraining,
+ NezhaForQuestionAnswering,
+ NezhaForSequenceClassification,
+ NezhaForTokenClassification,
+ NezhaModel,
+ NezhaPreTrainedModel,
+ )
from .models.nystromformer import (
NYSTROMFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
NystromformerForMaskedLM,
@@ -3522,6 +4099,21 @@
OpenAIGPTPreTrainedModel,
load_tf_weights_in_openai_gpt,
)
+ from .models.opt import (
+ OPT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ OPTForCausalLM,
+ OPTForSequenceClassification,
+ OPTModel,
+ OPTPreTrainedModel,
+ )
+ from .models.owlvit import (
+ OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ OwlViTForObjectDetection,
+ OwlViTModel,
+ OwlViTPreTrainedModel,
+ OwlViTTextModel,
+ OwlViTVisionModel,
+ )
from .models.pegasus import (
PegasusForCausalLM,
PegasusForConditionalGeneration,
@@ -3684,6 +4276,7 @@
from .models.speech_to_text_2 import Speech2Text2ForCausalLM, Speech2Text2PreTrainedModel
from .models.splinter import (
SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ SplinterForPreTraining,
SplinterForQuestionAnswering,
SplinterLayer,
SplinterModel,
@@ -3707,6 +4300,13 @@
SwinModel,
SwinPreTrainedModel,
)
+ from .models.swinv2 import (
+ SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST,
+ Swinv2ForImageClassification,
+ Swinv2ForMaskedImageModeling,
+ Swinv2Model,
+ Swinv2PreTrainedModel,
+ )
from .models.t5 import (
T5_PRETRAINED_MODEL_ARCHIVE_LIST,
T5EncoderModel,
@@ -3715,6 +4315,11 @@
T5PreTrainedModel,
load_tf_weights_in_t5,
)
+ from .models.trajectory_transformer import (
+ TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TrajectoryTransformerModel,
+ TrajectoryTransformerPreTrainedModel,
+ )
from .models.transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
AdaptiveEmbedding,
@@ -3749,12 +4354,20 @@
VanModel,
VanPreTrainedModel,
)
+ from .models.videomae import (
+ VIDEOMAE_PRETRAINED_MODEL_ARCHIVE_LIST,
+ VideoMAEForPreTraining,
+ VideoMAEForVideoClassification,
+ VideoMAEModel,
+ VideoMAEPreTrainedModel,
+ )
from .models.vilt import (
VILT_PRETRAINED_MODEL_ARCHIVE_LIST,
ViltForImageAndTextRetrieval,
ViltForImagesAndTextClassification,
ViltForMaskedLM,
ViltForQuestionAnswering,
+ ViltForTokenClassification,
ViltLayer,
ViltModel,
ViltPreTrainedModel,
@@ -3797,6 +4410,16 @@
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
+ from .models.wav2vec2_conformer import (
+ WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ Wav2Vec2ConformerForAudioFrameClassification,
+ Wav2Vec2ConformerForCTC,
+ Wav2Vec2ConformerForPreTraining,
+ Wav2Vec2ConformerForSequenceClassification,
+ Wav2Vec2ConformerForXVector,
+ Wav2Vec2ConformerModel,
+ Wav2Vec2ConformerPreTrainedModel,
+ )
from .models.wavlm import (
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST,
WavLMForAudioFrameClassification,
@@ -3895,12 +4518,16 @@
from .trainer import Trainer
from .trainer_pt_utils import torch_distributed_zero_first
from .trainer_seq2seq import Seq2SeqTrainer
- else:
- from .utils.dummy_pt_objects import *
# TensorFlow
- if is_tf_available():
-
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ # Import the same objects as dummies to get them in the namespace.
+ # They will raise an import error if the user tries to instantiate / use them.
+ from .utils.dummy_tf_objects import *
+ else:
from .benchmark.benchmark_args_tf import TensorFlowBenchmarkArguments
# Benchmarks
@@ -3931,6 +4558,8 @@
TFLayoutLMPreTrainedModel,
)
from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, shape_list
+
+ # TensorFlow model imports
from .models.albert import (
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAlbertForMaskedLM,
@@ -3946,11 +4575,13 @@
from .models.auto import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
TF_MODEL_FOR_PRETRAINING_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+ TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
@@ -3964,6 +4595,7 @@
TFAutoModelForImageClassification,
TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice,
+ TFAutoModelForNextSentencePrediction,
TFAutoModelForPreTraining,
TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM,
@@ -4038,6 +4670,7 @@
)
from .models.data2vec import (
TFData2VecVisionForImageClassification,
+ TFData2VecVisionForSemanticSegmentation,
TFData2VecVisionModel,
TFData2VecVisionPreTrainedModel,
)
@@ -4059,6 +4692,14 @@
TFDebertaV2Model,
TFDebertaV2PreTrainedModel,
)
+ from .models.deit import (
+ TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TFDeiTForImageClassification,
+ TFDeiTForImageClassificationWithTeacher,
+ TFDeiTForMaskedImageModeling,
+ TFDeiTModel,
+ TFDeiTPreTrainedModel,
+ )
from .models.distilbert import (
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDistilBertForMaskedLM,
@@ -4193,8 +4834,15 @@
TFOpenAIGPTModel,
TFOpenAIGPTPreTrainedModel,
)
+ from .models.opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
+ from .models.regnet import (
+ TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TFRegNetForImageClassification,
+ TFRegNetModel,
+ TFRegNetPreTrainedModel,
+ )
from .models.rembert import (
TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRemBertForCausalLM,
@@ -4207,6 +4855,12 @@
TFRemBertModel,
TFRemBertPreTrainedModel,
)
+ from .models.resnet import (
+ TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TFResNetForImageClassification,
+ TFResNetModel,
+ TFResNetPreTrainedModel,
+ )
from .models.roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForCausalLM,
@@ -4231,12 +4885,27 @@
TFRoFormerModel,
TFRoFormerPreTrainedModel,
)
+ from .models.segformer import (
+ TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TFSegformerDecodeHead,
+ TFSegformerForImageClassification,
+ TFSegformerForSemanticSegmentation,
+ TFSegformerModel,
+ TFSegformerPreTrainedModel,
+ )
from .models.speech_to_text import (
TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSpeech2TextForConditionalGeneration,
TFSpeech2TextModel,
TFSpeech2TextPreTrainedModel,
)
+ from .models.swin import (
+ TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TFSwinForImageClassification,
+ TFSwinForMaskedImageModeling,
+ TFSwinModel,
+ TFSwinPreTrainedModel,
+ )
from .models.t5 import (
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
TFT5EncoderModel,
@@ -4308,13 +4977,14 @@
# Trainer
from .trainer_tf import TFTrainer
- else:
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
# Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them.
- from .utils.dummy_tf_objects import *
-
- if is_flax_available():
-
+ from .utils.dummy_flax_objects import *
+ else:
from .generation_flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
@@ -4327,6 +4997,8 @@
FlaxTopPLogitsWarper,
)
from .modeling_flax_utils import FlaxPreTrainedModel
+
+ # Flax model imports
from .models.albert import (
FlaxAlbertForMaskedLM,
FlaxAlbertForMultipleChoice,
@@ -4443,6 +5115,7 @@
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
from .models.gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel
+ from .models.longt5 import FlaxLongT5ForConditionalGeneration, FlaxLongT5Model, FlaxLongT5PreTrainedModel
from .models.marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel
from .models.mbart import (
FlaxMBartForConditionalGeneration,
@@ -4451,7 +5124,8 @@
FlaxMBartModel,
FlaxMBartPreTrainedModel,
)
- from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
+ from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model
+ from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
from .models.roberta import (
FlaxRobertaForCausalLM,
@@ -4473,7 +5147,7 @@
FlaxRoFormerPreTrainedModel,
)
from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
- from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
+ from .models.t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
@@ -4492,10 +5166,6 @@
FlaxXLMRobertaForTokenClassification,
FlaxXLMRobertaModel,
)
- else:
- # Import the same objects as dummies to get them in the namespace.
- # They will raise an import error if the user tries to instantiate / use them.
- from .utils.dummy_flax_objects import *
else:
import sys
diff --git a/src/transformers/activations.py b/src/transformers/activations.py
index fad8d1061347..5d413bba728b 100644
--- a/src/transformers/activations.py
+++ b/src/transformers/activations.py
@@ -44,7 +44,7 @@ class GELUActivation(nn.Module):
def __init__(self, use_gelu_python: bool = False):
super().__init__()
- if version.parse(torch.__version__) < version.parse("1.4") or use_gelu_python:
+ if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.4") or use_gelu_python:
self.act = self._gelu_python
else:
self.act = nn.functional.gelu
@@ -110,7 +110,7 @@ class SiLUActivation(nn.Module):
def __init__(self):
super().__init__()
- if version.parse(torch.__version__) < version.parse("1.7"):
+ if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
self.act = self._silu_python
else:
self.act = nn.functional.silu
@@ -130,7 +130,7 @@ class MishActivation(nn.Module):
def __init__(self):
super().__init__()
- if version.parse(torch.__version__) < version.parse("1.9"):
+ if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.9"):
self.act = self._mish_python
else:
self.act = nn.functional.mish
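
A note on the three hunks above: torch nightlies report pre-release strings such as "1.9.0a0+git...", and `packaging` orders a pre-release before its final release, so the naive comparison could wrongly fall back to the slow Python activations. A minimal sketch of why `base_version` is used (not part of the patch; the version string is hypothetical):

from packaging import version

nightly = "1.9.0a0+gitabcdef"  # hypothetical torch.__version__ reported by a nightly build

# Naive check: the 1.9 pre-release sorts *before* 1.9, so this is True and the
# module would pick the legacy Python implementation even though the op exists.
print(version.parse(nightly) < version.parse("1.9"))                              # True

# Comparing on base_version ("1.9.0") gives the intended result.
print(version.parse(version.parse(nightly).base_version) < version.parse("1.9"))  # False
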
diff --git a/src/transformers/benchmark/benchmark.py b/src/transformers/benchmark/benchmark.py
index 8569c6e324e3..7f95e4b40b7c 100644
--- a/src/transformers/benchmark/benchmark.py
+++ b/src/transformers/benchmark/benchmark.py
@@ -96,7 +96,8 @@ def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_len
model = model_cls(config)
except ImportError:
raise ImportError(
- f"{model_class} does not exist. If you just want to test the pretrained model, you might want to set `--only_pretrain_model` or `args.only_pretrain_model=True`."
+ f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
+ " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
)
else:
model = MODEL_MAPPING[config.__class__](config)
@@ -151,7 +152,8 @@ def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length:
model = model_cls(config)
except ImportError:
raise ImportError(
- f"{model_class} does not exist. If you just want to test the pretrained model, you might want to set `--only_pretrain_model` or `args.only_pretrain_model=True`."
+ f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
+ " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
)
else:
model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
@@ -230,7 +232,8 @@ def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
if self.args.is_tpu:
# tpu
raise NotImplementedError(
- "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `--no-memory` or `args.memory=False`"
+ "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with"
+ " `--no-memory` or `args.memory=False`"
)
elif self.args.is_gpu:
if not is_py3nvml_available():
@@ -241,7 +244,8 @@ def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
memory = "N/A"
else:
logger.info(
- "Measuring total GPU usage on GPU device. Make sure to not have additional processes running on the same GPU."
+ "Measuring total GPU usage on GPU device. Make sure to not have additional processes running"
+ " on the same GPU."
)
# init nvml
nvml.nvmlInit()
diff --git a/src/transformers/benchmark/benchmark_args.py b/src/transformers/benchmark/benchmark_args.py
index dbdf9d8a3673..2d759ac34256 100644
--- a/src/transformers/benchmark/benchmark_args.py
+++ b/src/transformers/benchmark/benchmark_args.py
@@ -24,7 +24,7 @@
if is_torch_available():
import torch
-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
@@ -54,7 +54,8 @@ def __init__(self, **kwargs):
positive_arg = deprecated_arg[3:]
setattr(self, positive_arg, not kwargs.pop(deprecated_arg))
logger.warning(
- f"{deprecated_arg} is depreciated. Please use --no_{positive_arg} or {positive_arg}={kwargs[positive_arg]}"
+ f"{deprecated_arg} is depreciated. Please use --no_{positive_arg} or"
+ f" {positive_arg}={kwargs[positive_arg]}"
)
self.torchscript = kwargs.pop("torchscript", self.torchscript)
diff --git a/src/transformers/benchmark/benchmark_args_tf.py b/src/transformers/benchmark/benchmark_args_tf.py
index 7ec5054cb37c..8f3a9cea9465 100644
--- a/src/transformers/benchmark/benchmark_args_tf.py
+++ b/src/transformers/benchmark/benchmark_args_tf.py
@@ -51,7 +51,8 @@ def __init__(self, **kwargs):
positive_arg = deprecated_arg[3:]
kwargs[positive_arg] = not kwargs.pop(deprecated_arg)
logger.warning(
- f"{deprecated_arg} is depreciated. Please use --no-{positive_arg} or {positive_arg}={kwargs[positive_arg]}"
+ f"{deprecated_arg} is depreciated. Please use --no-{positive_arg} or"
+ f" {positive_arg}={kwargs[positive_arg]}"
)
self.tpu_name = kwargs.pop("tpu_name", self.tpu_name)
self.device_idx = kwargs.pop("device_idx", self.device_idx)
diff --git a/src/transformers/benchmark/benchmark_args_utils.py b/src/transformers/benchmark/benchmark_args_utils.py
index b2f76f809f18..d9233906d281 100644
--- a/src/transformers/benchmark/benchmark_args_utils.py
+++ b/src/transformers/benchmark/benchmark_args_utils.py
@@ -43,7 +43,10 @@ class BenchmarkArguments:
models: List[str] = list_field(
default=[],
metadata={
- "help": "Model checkpoints to be provided to the AutoModel classes. Leave blank to benchmark the base version of all available models"
+ "help": (
+ "Model checkpoints to be provided to the AutoModel classes. Leave blank to benchmark the base version"
+ " of all available models"
+ )
},
)
@@ -87,7 +90,11 @@ class BenchmarkArguments:
multi_process: bool = field(
default=True,
metadata={
- "help": "Whether to use multiprocessing for memory and speed measurement. It is highly recommended to use multiprocessing for accurate CPU and GPU memory measurements. This option should only be disabled for debugging / testing and on TPU."
+ "help": (
+ "Whether to use multiprocessing for memory and speed measurement. It is highly recommended to use"
+ " multiprocessing for accurate CPU and GPU memory measurements. This option should only be disabled"
+ " for debugging / testing and on TPU."
+ )
},
)
inference_time_csv_file: str = field(
@@ -118,7 +125,10 @@ class BenchmarkArguments:
only_pretrain_model: bool = field(
default=False,
metadata={
- "help": "Instead of loading the model as defined in `config.architectures` if exists, just load the pretrain model weights."
+ "help": (
+ "Instead of loading the model as defined in `config.architectures` if exists, just load the pretrain"
+ " model weights."
+ )
},
)
@@ -138,9 +148,10 @@ def to_json_string(self):
@property
def model_names(self):
- assert (
- len(self.models) > 0
- ), "Please make sure you provide at least one model name / model identifier, *e.g.* `--models bert-base-cased` or `args.models = ['bert-base-cased']."
+ assert len(self.models) > 0, (
+ "Please make sure you provide at least one model name / model identifier, *e.g.* `--models"
+ " bert-base-cased` or `args.models = ['bert-base-cased']."
+ )
return self.models
@property
diff --git a/src/transformers/benchmark/benchmark_tf.py b/src/transformers/benchmark/benchmark_tf.py
index 0eb0db64a8d6..b5fd4b71b562 100644
--- a/src/transformers/benchmark/benchmark_tf.py
+++ b/src/transformers/benchmark/benchmark_tf.py
@@ -140,7 +140,8 @@ def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_len
model = model_cls(config)
except ImportError:
raise ImportError(
- f"{model_class} does not exist. If you just want to test the pretrained model, you might want to set `--only_pretrain_model` or `args.only_pretrain_model=True`."
+ f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
+ " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
)
else:
model = TF_MODEL_MAPPING[config.__class__](config)
@@ -184,7 +185,8 @@ def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length:
model = model_cls(config)
except ImportError:
raise ImportError(
- f"{model_class} does not exist. If you just want to test the pretrained model, you might want to set `--only_pretrain_model` or `args.only_pretrain_model=True`."
+ f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
+ " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
)
else:
model = TF_MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
@@ -239,15 +241,17 @@ def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
with self.args.strategy.scope():
try:
if self.args.trace_memory_line_by_line:
- assert (
- self.args.eager_mode
- ), "`args.eager_mode` is set to `False`. Make sure to run model in eager mode to measure memory consumption line by line."
+ assert self.args.eager_mode, (
+ "`args.eager_mode` is set to `False`. Make sure to run model in eager mode to measure memory"
+ " consumption line by line."
+ )
trace = start_memory_tracing("transformers")
if self.args.is_tpu:
# tpu
raise NotImplementedError(
- "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.memory=False`"
+ "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking"
+ " with `args.memory=False`"
)
elif self.args.is_gpu:
# gpu
@@ -259,7 +263,8 @@ def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
memory = "N/A"
else:
logger.info(
- "Measuring total GPU usage on GPU device. Make sure to not have additional processes running on the same GPU."
+ "Measuring total GPU usage on GPU device. Make sure to not have additional processes"
+ " running on the same GPU."
)
# init nvml
nvml.nvmlInit()
@@ -274,7 +279,8 @@ def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
# cpu
if self.args.trace_memory_line_by_line:
logger.info(
- "When enabling line by line tracing, the max peak memory for CPU is inaccurate in TensorFlow."
+ "When enabling line by line tracing, the max peak memory for CPU is inaccurate in"
+ " TensorFlow."
)
memory = None
else:
diff --git a/src/transformers/benchmark/benchmark_utils.py b/src/transformers/benchmark/benchmark_utils.py
index 7e738bb601cf..36fe5eb116cb 100644
--- a/src/transformers/benchmark/benchmark_utils.py
+++ b/src/transformers/benchmark/benchmark_utils.py
@@ -379,7 +379,7 @@ def start_memory_tracing(
devices = list(range(nvml.nvmlDeviceGetCount())) if gpus_to_trace is None else gpus_to_trace
nvml.nvmlShutdown()
except (OSError, nvml.NVMLError):
- logger.warning("Error while initializing communication with GPU. " "We won't perform GPU memory tracing.")
+ logger.warning("Error while initializing communication with GPU. We won't perform GPU memory tracing.")
log_gpu = False
else:
log_gpu = is_torch_available() or is_tf_available()
@@ -626,7 +626,8 @@ def __init__(self, args: BenchmarkArguments = None, configs: PretrainedConfig =
if self.args.memory and os.getenv("TRANSFORMERS_USE_MULTIPROCESSING") == 0:
logger.warning(
- "Memory consumption will not be measured accurately if `args.multi_process` is set to `False.` The flag 'TRANSFORMERS_USE_MULTIPROCESSING' should only be disabled for debugging / testing."
+ "Memory consumption will not be measured accurately if `args.multi_process` is set to `False.` The"
+ " flag 'TRANSFORMERS_USE_MULTIPROCESSING' should only be disabled for debugging / testing."
)
self._print_fn = None
@@ -732,7 +733,8 @@ def run(self):
self.save_to_csv(inference_result_time, self.args.inference_time_csv_file)
if self.args.is_tpu:
self.print_fn(
- "TPU was used for inference. Note that the time after compilation stabilized (after ~10 inferences model.forward(..) calls) was measured."
+ "TPU was used for inference. Note that the time after compilation stabilized (after ~10"
+ " inferences model.forward(..) calls) was measured."
)
if self.args.memory:
@@ -751,7 +753,8 @@ def run(self):
self.save_to_csv(train_result_time, self.args.train_time_csv_file)
if self.args.is_tpu:
self.print_fn(
- "TPU was used for training. Note that the time after compilation stabilized (after ~10 train loss=model.forward(...) + loss.backward() calls) was measured."
+ "TPU was used for training. Note that the time after compilation stabilized (after ~10 train"
+ " loss=model.forward(...) + loss.backward() calls) was measured."
)
if self.args.memory:
diff --git a/src/transformers/commands/add_new_model_like.py b/src/transformers/commands/add_new_model_like.py
index 625af5001116..c49f3ad86904 100644
--- a/src/transformers/commands/add_new_model_like.py
+++ b/src/transformers/commands/add_new_model_like.py
@@ -18,6 +18,7 @@
import re
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
+from datetime import date
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
@@ -32,6 +33,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+CURRENT_YEAR = date.today().year
TRANSFORMERS_PATH = Path(__file__).parent.parent
REPO_PATH = TRANSFORMERS_PATH.parent.parent
@@ -421,6 +423,7 @@ def duplicate_module(
with open(module_file, "r", encoding="utf-8") as f:
content = f.read()
+ content = re.sub("# Copyright (\d+)\s", f"# Copyright {CURRENT_YEAR} ", content)
objects = parse_module_content(content)
# Loop and treat all objects
@@ -766,7 +769,9 @@ def clean_frameworks_in_init(
return
remove_pattern = "|".join(to_remove)
- re_conditional_imports = re.compile(rf"^\s*if is_({remove_pattern})_available\(\):\s*$")
+ re_conditional_imports = re.compile(rf"^\s*if not is_({remove_pattern})_available\(\):\s*$")
+ re_try = re.compile(r"\s*try:")
+ re_else = re.compile(r"\s*else:")
re_is_xxx_available = re.compile(rf"is_({remove_pattern})_available")
with open(init_file, "r", encoding="utf-8") as f:
@@ -776,11 +781,15 @@ def clean_frameworks_in_init(
new_lines = []
idx = 0
while idx < len(lines):
- # Conditional imports
- if re_conditional_imports.search(lines[idx]) is not None:
+ # Conditional imports in try-except-else blocks
+ if (re_conditional_imports.search(lines[idx]) is not None) and (re_try.search(lines[idx - 1]) is not None):
+ # Remove the preceding `try:`
+ new_lines.pop()
idx += 1
- while is_empty_line(lines[idx]):
+ # Iterate until `else:`
+ while is_empty_line(lines[idx]) or re_else.search(lines[idx]) is None:
idx += 1
+ idx += 1
indent = find_indent(lines[idx])
while find_indent(lines[idx]) >= indent or is_empty_line(lines[idx]):
idx += 1
@@ -790,6 +799,7 @@ def clean_frameworks_in_init(
for framework in to_remove:
line = line.replace(f", is_{framework}_available", "")
line = line.replace(f"is_{framework}_available, ", "")
+ line = line.replace(f"is_{framework}_available,", "")
line = line.replace(f"is_{framework}_available", "")
if len(line.strip()) > 0:
@@ -834,14 +844,24 @@ def add_model_to_main_init(
new_lines = []
framework = None
while idx < len(lines):
+ new_framework = False
if not is_empty_line(lines[idx]) and find_indent(lines[idx]) == 0:
framework = None
- elif lines[idx].lstrip().startswith("if is_torch_available"):
+ elif lines[idx].lstrip().startswith("if not is_torch_available"):
framework = "pt"
- elif lines[idx].lstrip().startswith("if is_tf_available"):
+ new_framework = True
+ elif lines[idx].lstrip().startswith("if not is_tf_available"):
framework = "tf"
- elif lines[idx].lstrip().startswith("if is_flax_available"):
+ new_framework = True
+ elif lines[idx].lstrip().startswith("if not is_flax_available"):
framework = "flax"
+ new_framework = True
+
+ if new_framework:
+ # For a new framework, we need to skip until the else: block to get where the imports are.
+ while lines[idx].strip() != "else:":
+ new_lines.append(lines[idx])
+ idx += 1
# Skip if we are in a framework not wanted.
if framework is not None and frameworks is not None and framework not in frameworks:
@@ -1055,6 +1075,7 @@ def duplicate_doc_file(
with open(doc_file, "r", encoding="utf-8") as f:
content = f.read()
+ content = re.sub(" There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
@@ -809,7 +826,7 @@ def forward(
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
- attention_mask = (1.0 - attention_mask) * -10000.0
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1326,7 +1343,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
GPT2_START_DOCSTRING,
)
class GPT2ForSequenceClassification(GPT2PreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1406,10 +1423,10 @@ def forward(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
- pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
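
The hunk above also replaces the hard-coded -10000.0 additive mask with the most negative value representable in the model's dtype, so the same code saturates correctly in fp16/bf16 as well as fp32. A minimal sketch of that arithmetic (illustrative only, outside the patch):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])          # 1 = attend, 0 = padding
dtype = torch.float16

mask = attention_mask[:, None, None, :].to(dtype)       # broadcastable over [batch, heads, query, key]
additive_mask = (1.0 - mask) * torch.finfo(dtype).min   # 0.0 for real tokens, -65504.0 for padding
print(additive_mask)
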
diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py
index 45d29b6779ee..b71c37dc48db 100644
--- a/src/transformers/models/gpt2/modeling_tf_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py
@@ -20,7 +20,6 @@
import numpy as np
import tensorflow as tf
-from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
@@ -813,25 +812,21 @@ def get_output_embeddings(self):
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
- def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs):
- # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
- # tests will need to be fixed after the change
-
+ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
+ token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
+ if token_type_ids is not None:
+ token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
+
+ position_ids = kwargs.get("position_ids", None)
+ attention_mask = kwargs.get("attention_mask", None)
- # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
- # for a future PR to not change too many things for now.
- # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
- position_ids = None
- attention_mask = None
- if use_xla:
- attention_mask = kwargs.get("attention_mask", None)
- if past is not None and attention_mask is not None:
- position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
- elif attention_mask is not None:
- position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
+ if attention_mask is not None and position_ids is None:
+ position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
+ if past:
+ position_ids = tf.expand_dims(position_ids[:, -1], -1)
return {
"input_ids": inputs,
@@ -839,64 +834,9 @@ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_x
"position_ids": position_ids,
"past": past,
"use_cache": use_cache,
+ "token_type_ids": token_type_ids,
}
- def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
- # TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
- # quite some duplicated code patterns it seems
- # also the `attention_mask` is currently used in a somewhat hacky to
- # correctly influence the `past_key_values` - not sure if this is the way to go
- # Let's keep that for a future PR.
- past = outputs.past_key_values
- is_past_initialized = model_kwargs.pop("past", None) is not None
- attention_mask = model_kwargs.pop("attention_mask")
- batch_size = attention_mask.shape[0]
-
- if not is_past_initialized:
- # past[0].shape[3] is seq_length of prompt
- num_padding_values = max_length - past[0].shape[3] - 1
-
- padding_values = np.zeros((5, 2), dtype=np.int32)
- padding_values[3, 1] = num_padding_values
- padding_values = tf.constant(padding_values)
-
- new_past = list(past)
- for i in range(len(past)):
- new_past[i] = tf.pad(past[i], padding_values)
-
- # Zeros for the currently-unfilled locations in the past tensor, ones for the actual input_ids
- attention_mask = tf.concat(
- [
- attention_mask,
- tf.zeros((batch_size, num_padding_values), dtype=attention_mask.dtype),
- tf.ones((batch_size, 1), dtype=attention_mask.dtype),
- ],
- axis=1,
- )
- else:
- new_past = [None for _ in range(len(past))]
- slice_start_base = tf.constant([0, 0, 0, 1, 0])
- attention_mask_update_slice = tf.ones((batch_size, 1), dtype=attention_mask.dtype)
- # correct 5 here
- new_past_index = current_pos - 1
-
- for i in range(len(past)):
- update_slice = past[i][:, :, :, -1:]
- # Write the last slice to the first open location in the padded past array
- # and then truncate the last slice off the array
- new_past[i] = dynamic_update_slice(
- past[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index
- )
-
- update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
- attention_mask = dynamic_update_slice(attention_mask, attention_mask_update_slice, update_start)
-
- # set `attention_mask` and `past`
- model_kwargs["attention_mask"] = attention_mask
- model_kwargs["past"] = tuple(new_past)
-
- return model_kwargs
-
@unpack_inputs
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
@@ -1240,7 +1180,7 @@ def call(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None
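
The rewritten `prepare_inputs_for_generation` above derives `position_ids` from the attention mask with an exclusive cumulative sum, so left-padded prompts keep counting positions from the first real token. A small illustrative sketch of that arithmetic (not part of the patch):

import tensorflow as tf

attention_mask = tf.constant([[0, 0, 1, 1, 1]])  # two left-padding tokens, then three real tokens
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
print(position_ids.numpy())  # [[0 0 0 1 2]] -- the real tokens get positions 0, 1, 2
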
diff --git a/src/transformers/models/gpt2/tokenization_gpt2.py b/src/transformers/models/gpt2/tokenization_gpt2.py
index 6a6f49b1f988..b480eca0c062 100644
--- a/src/transformers/models/gpt2/tokenization_gpt2.py
+++ b/src/transformers/models/gpt2/tokenization_gpt2.py
@@ -162,20 +162,26 @@ def __init__(
unk_token="<|endoftext|>",
bos_token="<|endoftext|>",
eos_token="<|endoftext|>",
+ pad_token=None,
add_prefix_space=False,
+ add_bos_token=False,
**kwargs
):
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
super().__init__(
errors=errors,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
+ pad_token=pad_token,
add_prefix_space=add_prefix_space,
+ add_bos_token=add_bos_token,
**kwargs,
)
+ self.add_bos_token = add_bos_token
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
@@ -242,6 +248,19 @@ def bpe(self, token):
self.cache[token] = word
return word
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ if self.add_bos_token:
+ bos_token_ids = [self.bos_token_id]
+ else:
+ bos_token_ids = []
+
+ output = bos_token_ids + token_ids_0
+
+ if token_ids_1 is None:
+ return output
+
+ return output + bos_token_ids + token_ids_1
+
def _tokenize(self, text):
"""Tokenize a string."""
bpe_tokens = []
@@ -278,7 +297,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
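
A hypothetical usage sketch (not part of the patch) of the new `add_bos_token` flag: when it is enabled, `build_inputs_with_special_tokens` above prepends the BOS id (for GPT-2 this is the <|endoftext|> id) to every encoded sequence.

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2", add_bos_token=True)
ids = tokenizer("Hello world").input_ids
print(ids[0] == tokenizer.bos_token_id)  # True: the encoded sequence now starts with BOS
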
diff --git a/src/transformers/models/gpt2/tokenization_gpt2_fast.py b/src/transformers/models/gpt2/tokenization_gpt2_fast.py
index e244a5d21e6f..ddd4ad56fde1 100644
--- a/src/transformers/models/gpt2/tokenization_gpt2_fast.py
+++ b/src/transformers/models/gpt2/tokenization_gpt2_fast.py
@@ -146,6 +146,17 @@ def __init__(
**kwargs,
)
+ if kwargs.pop("add_bos_token", False):
+ model_id = kwargs.pop("name_or_path", "")
+ raise ValueError(
+ "Currenty GPT2's fast tokenizer does NOT support adding a BOS token."
+ "Instead you should use GPT2's slow tokenizer class `GPT2Tokenizer` as follows: \n"
+ f"`GPT2Tokenizer.from_pretrained('{model_id}')`\nor\n"
+ f"`AutoTokenizer.from_pretrained('{model_id}', use_fast=False)`\n"
+ "This issue will be fixed soon, see: https://github.com/huggingface/tokenizers/pull/1005."
+ " so that the fast tokenizer works correctly."
+ )
+
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
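
A hedged sketch of the guard above: requesting `add_bos_token` on the fast GPT-2 tokenizer is rejected with a pointer to the slow class (behaviour as implemented in this patch; the checkpoint name is illustrative).

from transformers import GPT2TokenizerFast

try:
    GPT2TokenizerFast.from_pretrained("gpt2", add_bos_token=True)
except ValueError as err:
    print(err)  # suggests GPT2Tokenizer / AutoTokenizer(..., use_fast=False) instead
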
diff --git a/src/transformers/models/gpt_neo/__init__.py b/src/transformers/models/gpt_neo/__init__.py
index d039b6f43974..b57f7c3f9760 100644
--- a/src/transformers/models/gpt_neo/__init__.py
+++ b/src/transformers/models/gpt_neo/__init__.py
@@ -17,14 +17,19 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available
_import_structure = {
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig", "GPTNeoOnnxConfig"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_gpt_neo"] = [
"GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTNeoForCausalLM",
@@ -34,7 +39,12 @@
"load_tf_weights_in_gpt_neo",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_gpt_neo"] = [
"FlaxGPTNeoForCausalLM",
"FlaxGPTNeoModel",
@@ -45,7 +55,12 @@
if TYPE_CHECKING:
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig, GPTNeoOnnxConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_gpt_neo import (
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTNeoForCausalLM,
@@ -55,7 +70,12 @@
load_tf_weights_in_gpt_neo,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
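
The gpt_neo `__init__.py` rewrite above follows the try/except/else layout used throughout this patch: the availability check raises inside `try`, the `except` branch degrades gracefully (a bare `pass` here, dummy objects in the top-level init), and the real imports sit in `else`. A distilled sketch of the pattern (illustrative, not part of the patch):

from transformers.utils import OptionalDependencyNotAvailable, is_torch_available

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass  # torch is missing: expose nothing (or dummy objects that raise on use)
else:
    from transformers.models.gpt_neo import GPTNeoForCausalLM  # real import only runs when torch is present
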
diff --git a/src/transformers/models/gpt_neo/configuration_gpt_neo.py b/src/transformers/models/gpt_neo/configuration_gpt_neo.py
index dc47db0a8a19..00054a2c6bb0 100644
--- a/src/transformers/models/gpt_neo/configuration_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/configuration_gpt_neo.py
@@ -261,8 +261,9 @@ def generate_dummy_inputs(
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
if self.use_past:
+ mask_dtype = ordered_inputs["attention_mask"].dtype
ordered_inputs["attention_mask"] = torch.cat(
- [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+ [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
)
return ordered_inputs
diff --git a/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py b/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py
index 7ee1c17477eb..4a5fddd0a9d0 100644
--- a/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py
+++ b/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py
@@ -60,8 +60,10 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained mesh-tf model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained mesh-tf model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 9fcbf57c733b..c30db4e347f4 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -147,15 +147,16 @@ def __init__(self, config, attention_type):
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9))
- self.attn_dropout = nn.Dropout(config.attention_dropout)
- self.resid_dropout = nn.Dropout(config.resid_dropout)
+ self.attn_dropout = nn.Dropout(float(config.attention_dropout))
+ self.resid_dropout = nn.Dropout(float(config.resid_dropout))
self.embed_dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
)
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
@@ -187,8 +188,12 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
- attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
# Apply the attention mask
@@ -289,7 +294,7 @@ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 *
self.c_fc = nn.Linear(embed_dim, intermediate_size)
self.c_proj = nn.Linear(intermediate_size, embed_dim)
self.act = ACT2FN[config.activation_function]
- self.dropout = nn.Dropout(config.resid_dropout)
+ self.dropout = nn.Dropout(float(config.resid_dropout))
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
@@ -357,6 +362,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_gpt_neo
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
+ _no_split_modules = ["GPTNeoBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -474,7 +480,7 @@ def __init__(self, config):
self.embed_dim = config.hidden_size
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
- self.drop = nn.Dropout(config.embed_dropout)
+ self.drop = nn.Dropout(float(config.embed_dropout))
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -541,7 +547,6 @@ def forward(
else:
past_length = past_key_values[0][0].size(-2)
- device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
@@ -564,7 +569,7 @@ def forward(
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
- attention_mask = (1.0 - attention_mask) * -10000.0
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@@ -659,7 +664,7 @@ def custom_forward(*inputs):
class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
_keys_to_ignore_on_load_missing = [
r"h\.\d+\.attn\.masked_bias",
- r"lm_head\.weight",
+ r"lm_head.weight",
r"h\.\d+\.attn\.attention\.bias",
]
_keys_to_ignore_on_save = [r"lm_head.weight"]
@@ -810,7 +815,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
GPT_NEO_START_DOCSTRING,
)
class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -883,10 +888,10 @@ def forward(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
- pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
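
The `_no_split_modules = ["GPTNeoBlock"]` attribute added above is what lets the accelerate-backed loader keep each block on a single device when inferring a device map. A hypothetical usage sketch (assumes `accelerate` is installed; the checkpoint name is illustrative):

from transformers import AutoModelForCausalLM

# With device_map="auto", weights are spread over the available devices, and whole
# GPTNeoBlock modules are kept together as declared by _no_split_modules.
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B", device_map="auto")
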
diff --git a/src/transformers/models/gpt_neox/__init__.py b/src/transformers/models/gpt_neox/__init__.py
new file mode 100644
index 000000000000..814fa9a30131
--- /dev/null
+++ b/src/transformers/models/gpt_neox/__init__.py
@@ -0,0 +1,78 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable
+
+
+_import_structure = {"configuration_gpt_neox": ["GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXConfig"]}
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_gpt_neox_fast"] = ["GPTNeoXTokenizerFast"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_gpt_neox"] = [
+ "GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "GPTNeoXForCausalLM",
+ "GPTNeoXLayer",
+ "GPTNeoXModel",
+ "GPTNeoXPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_gpt_neox import GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoXConfig
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_gpt_neox_fast import GPTNeoXTokenizerFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_gpt_neox import (
+ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
+ GPTNeoXForCausalLM,
+ GPTNeoXLayer,
+ GPTNeoXModel,
+ GPTNeoXPreTrainedModel,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py
new file mode 100644
index 000000000000..8e906225c0d1
--- /dev/null
+++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py
@@ -0,0 +1,117 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" GPTNeoX model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "EleutherAI/gpt-neox-20b": "https://huggingface.co/EleutherAI/gpt-neox-20b/resolve/main/config.json",
+ # See all GPTNeoX models at https://huggingface.co/models?filter=gpt_neox
+}
+
+
+class GPTNeoXConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GPTNeoXModel`]. It is used to instantiate a
+ GPTNeoX model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the GPTNeoX
+ [EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50432):
+ Vocabulary size of the GPTNeoX model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`GPTNeoXModel`].
+ hidden_size (`int`, *optional*, defaults to 6144):
+ Dimension of the decoder layers.
+ num_hidden_layers (`int`, *optional*, defaults to 44):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 64):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ intermediate_size (`int`, *optional*, defaults to 24576):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer decoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ rotary_pct (`float`, *optional*, defaults to 0.25):
+ Percentage of hidden dimensions to allocate to rotary embeddings.
+ rotary_emb_base (`int`, *optional*, defaults to 10000):
+ Base for computing rotary embedding frequencies.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ Example:
+
+ ```python
+ >>> from transformers import GPTNeoXModel, GPTNeoXConfig
+
+ >>> # Initializing a GPTNeoX gpt-neox-20b style configuration
+ >>> configuration = GPTNeoXConfig()
+
+ >>> # Initializing a model from the gpt-neox-20b style configuration
+ >>> model = GPTNeoXModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "gpt_neox"
+
+ def __init__(
+ self,
+ vocab_size=50432,
+ hidden_size=6144,
+ num_hidden_layers=44,
+ num_attention_heads=64,
+ intermediate_size=24576,
+ hidden_act="gelu",
+ rotary_pct=0.25,
+ rotary_emb_base=10000,
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ use_cache=True,
+ bos_token_id=0,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ **kwargs
+ ):
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.rotary_pct = rotary_pct
+ self.rotary_emb_base = rotary_emb_base
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.use_cache = use_cache
+ self.tie_word_embeddings = tie_word_embeddings
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
new file mode 100755
index 000000000000..569ead7bdf3f
--- /dev/null
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -0,0 +1,673 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch GPTNeoX model."""
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...file_utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ replace_return_docstrings,
+)
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+from .configuration_gpt_neox import GPTNeoXConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "gpt-neox-20b"
+_CONFIG_FOR_DOC = "GPTNeoXConfig"
+_TOKENIZER_FOR_DOC = "GPTNeoXTokenizerFast"
+
+GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "EleutherAI/gpt-neox-20b",
+ # See all GPTNeoX models at https://huggingface.co/models?filter=gpt_neox
+]
+
+
+class GPTNeoXPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = GPTNeoXConfig
+ base_model_prefix = "gpt_neox"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["GPTNeoXLayer"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, GPTNeoXModel):
+ module.gradient_checkpointing = value
+
+
+class GPTNeoXAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.num_attention_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_attention_heads
+ self.rotary_ndims = int(self.head_size * config.rotary_pct)
+ max_positions = config.max_position_embeddings
+ self.register_buffer(
+ "bias",
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+ 1, 1, max_positions, max_positions
+ ),
+ )
+ self.register_buffer("masked_bias", torch.tensor(-1e9))
+ self.rotary_emb = RotaryEmbedding(
+ self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
+ )
+ self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype())
+ self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ head_mask=None,
+ layer_past=None,
+ use_cache=False,
+ output_attentions=False,
+ ):
+ has_layer_past = layer_past is not None
+
+ # Compute QKV
+ # Attention heads [batch, seq_len, hidden_size]
+ # --> [batch, seq_len, (np * 3 * head_size)]
+ qkv = self.query_key_value(hidden_states)
+
+ # [batch, seq_len, (num_heads * 3 * head_size)]
+ # --> [batch, seq_len, num_heads, 3 * head_size]
+ new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
+ qkv = qkv.view(*new_qkv_shape)
+
+ # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
+ query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
+ key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
+ value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
+
+ # Compute rotary embeddings on rotary_ndims
+ query_rot = query[..., : self.rotary_ndims]
+ query_pass = query[..., self.rotary_ndims :]
+ key_rot = key[..., : self.rotary_ndims]
+ key_pass = key[..., self.rotary_ndims :]
+
+ # Compute token offset for rotary embeddings (when decoding)
+ seq_len = key.shape[-2]
+ offset = 0
+ if has_layer_past:
+ offset = layer_past[0].shape[-2]
+ seq_len += offset
+ cos, sin = self.rotary_emb(value, seq_len=seq_len)
+ query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
+ query = torch.cat((query, query_pass), dim=-1)
+ key = torch.cat((key, key_pass), dim=-1)
+
+ # Cache QKV values
+ if has_layer_past:
+ past_key = layer_past[0]
+ past_value = layer_past[1]
+ key = torch.cat((past_key, key), dim=-2)
+ value = torch.cat((past_value, value), dim=-2)
+ present = (key, value) if use_cache else None
+
+ # Compute attention
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+ # Reshape outputs
+ attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
+ attn_output = self.dense(attn_output)
+
+ outputs = (attn_output, present)
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+ @classmethod
+ def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
+ """
+ Splits hidden dim into attn_head_size and num_attention_heads
+ """
+ # tensor: [bs, seq_len, hidden_size]
+ new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
+ # -> [bs, seq_len, num_attention_heads, attn_head_size]
+ tensor = tensor.view(new_shape)
+ # -> [bs, num_attention_heads, seq_len, attn_head_size]
+ tensor = tensor.permute(0, 2, 1, 3)
+ return tensor
+
+ @classmethod
+ def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
+ """
+ Merges attn_head_size dim and num_attn_heads dim into hidden dim
+ """
+ # tensor [bs, num_attention_heads, seq_len, attn_head_size]
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ # -> [bs, seq_len, num_attention_heads, attn_head_size]
+ tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size)
+ # -> [bs, seq_len, hidden_size]
+ return tensor
+
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+ # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
+ # compute causal mask from causal mask buffer
+ batch_size, num_attention_heads, query_length, attn_head_size = query.size()
+ key_length = key.size(-2)
+
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+
+ query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
+ key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
+ attn_scores = torch.zeros(
+ batch_size * num_attention_heads,
+ query_length,
+ key_length,
+ dtype=query.dtype,
+ device=key.device,
+ )
+ attn_scores = torch.baddbmm(
+ attn_scores,
+ query,
+ key.transpose(1, 2),
+ beta=1.0,
+ alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
+ )
+ attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
+
+ mask_value = torch.finfo(attn_scores.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
+ attn_scores = torch.where(causal_mask, attn_scores, mask_value)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attn_scores = attn_scores + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_scores, dim=-1)
+ attn_weights = attn_weights.to(value.dtype)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ return attn_output, attn_weights
+
+
+def attention_mask_func(attention_scores, ltor_mask):
+ attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
+ return attention_scores
+
+
+class RotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings, base=10000, device=None):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ # Build here to make `torch.jit.trace` work.
+ self.max_seq_len_cached = max_position_embeddings
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.cos_cached = emb.cos()[None, None, :, :]
+ self.sin_cached = emb.sin()[None, None, :, :]
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+ if seq_len > self.max_seq_len_cached:
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ self.cos_cached = emb.cos()[None, None, :, :]
+ self.sin_cached = emb.sin()[None, None, :, :]
+ return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
+ cos = cos[..., offset : q.shape[-2] + offset, :]
+ sin = sin[..., offset : q.shape[-2] + offset, :]
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class GPTNeoXMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.act = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense_h_to_4h(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.dense_4h_to_h(hidden_states)
+ return hidden_states
+
+
+class GPTNeoXLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attention = GPTNeoXAttention(config)
+ self.mlp = GPTNeoXMLP(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ use_cache=False,
+ layer_past=None,
+ output_attentions=False,
+ ):
+ residual = hidden_states
+ ln_out = self.input_layernorm(hidden_states)
+ attention_layer_outputs = self.attention(
+ ln_out,
+ attention_mask=attention_mask,
+ layer_past=layer_past,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ attn_output = attention_layer_outputs[0] # output_attn: a, present, (attentions)
+ outputs = attention_layer_outputs[1:]
+
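+ # GPT-NeoX computes attention and the MLP in parallel: both branches read the block input
+ # (through their own layer norms) and their outputs are summed with the residual below.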
+ mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
+ hidden_states = mlp_output + attn_output + residual
+
+ if use_cache:
+ outputs = (hidden_states,) + outputs
+ else:
+ outputs = (hidden_states,) + outputs[1:]
+
+ return outputs # hidden_states, present, (attentions)
+
+
+GPT_NEOX_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`~GPTNeoXConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+GPT_NEOX_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`GPTNeoXTokenizerFast`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare GPTNeoX Model transformer outputting raw hidden-states without any specific head on top.",
+ GPT_NEOX_START_DOCSTRING,
+)
+class GPTNeoXModel(GPTNeoXPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_in
+
+ def set_input_embeddings(self, value):
+ self.embed_in = value
+
+ @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPast,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ r"""
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+
+ if past_key_values is None:
+ past_key_values = tuple([None] * self.config.num_hidden_layers)
+
+ # Attention mask.
+ if attention_mask is not None:
+ assert batch_size > 0, "batch_size has to be defined and > 0"
+ attention_mask = attention_mask.view(batch_size, -1)
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is simpler than the triangular masking of causal attention
+ # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_in(input_ids)
+
+ hidden_states = inputs_embeds
+
+ presents = () if use_cache else None
+ all_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask[i],
+ layer_past=layer_past,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+ if output_attentions:
+ all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
+
+ hidden_states = self.final_layer_norm(hidden_states)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ )
+
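+# Illustrative sketch (not part of the original PR): running the bare model with the cache
+# enabled. The checkpoint name is an assumption taken from the tokenizer file in this PR.
+#
+#     model = GPTNeoXModel.from_pretrained("EleutherAI/gpt-neox-20b")
+#     out = model(input_ids=torch.tensor([[0, 1, 2]]), use_cache=True)
+#     out.last_hidden_state.shape   # (1, 3, config.hidden_size)
+#     len(out.past_key_values)      # one cached entry per hidden layer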
+
+@add_start_docstrings(
+ """GPTNeoX Model with a `language modeling` head on top for CLM fine-tuning.""", GPT_NEOX_START_DOCSTRING
+)
+class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
+
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.gpt_neox = GPTNeoXModel(config)
+ self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.embed_out
+
+ def set_output_embeddings(self, new_embeddings):
+ self.embed_out = new_embeddings
+
+ @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import GPTNeoXTokenizerFast, GPTNeoXForCausalLM, GPTNeoXConfig
+ >>> import torch
+
+ >>> tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
+ >>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b")
+ >>> config.is_decoder = True
+ >>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config)
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.gpt_neox(
+ input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ lm_logits = self.embed_out(hidden_states)
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shift_logits = lm_logits[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss()
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=lm_loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ # cut decoder_input_ids if past is used
+ if past and past[0] is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+ )
+ return reordered_past
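+
+# Illustrative sketch (not part of the original PR): one step of greedy decoding reusing the
+# cache returned by the model. The checkpoint name is an assumption.
+#
+#     tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
+#     model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b")
+#     inputs = tokenizer("GPT-NeoX is a", return_tensors="pt")
+#     out = model(**inputs, use_cache=True)
+#     next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
+#     # only the new token is fed back; the cached key/values cover the earlier positions
+#     out = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)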
diff --git a/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py b/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py
new file mode 100644
index 000000000000..c08d533835d7
--- /dev/null
+++ b/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py
@@ -0,0 +1,142 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for GPTNeoX."""
+import json
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from tokenizers import pre_tokenizers
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+
+
+if TYPE_CHECKING:
+ from transformers.pipelines.conversational import Conversation
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "tokenizer_file": {
+ "EleutherAI/gpt-neox-20b": "https://huggingface.co/EleutherAI/gpt-neox-20b/resolve/main/tokenizer.json",
+ },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "gpt-neox-20b": 2048,
+}
+
+
+class GPTNeoXTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" GPT-NeoX-20B tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
+ Byte-Pair-Encoding.
+
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+ ```
+ >>> from transformers import GPTNeoXTokenizerFast
+ >>> tokenizer = GPTNeoXTokenizerFast.from_pretrained("gpt2")
+ >>> tokenizer("Hello world")['input_ids']
+ [15496, 995]
+ >>> tokenizer(" Hello world")['input_ids']
+ [18435, 995]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
+ the model was not pretrained this way, it might yield a decrease in performance.
+
+
+
+ When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
+
+
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ unk_token (`str`, *optional*, defaults to `<|endoftext|>`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `<|endoftext|>`):
+ The beginning of sequence token.
+ eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
+ The end of sequence token.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows treating the leading word just as any
+ other word. (The GPTNeoX tokenizer detects the beginning of words by the preceding space.)
+ trim_offsets (`bool`, *optional*, defaults to `True`):
+ Whether or not the post-processing step should trim offsets to avoid including whitespaces.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file=None,
+ merges_file=None,
+ tokenizer_file=None,
+ unk_token="<|endoftext|>",
+ bos_token="<|endoftext|>",
+ eos_token="<|endoftext|>",
+ add_prefix_space=False,
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file,
+ merges_file,
+ tokenizer_file=tokenizer_file,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ add_prefix_space=add_prefix_space,
+ **kwargs,
+ )
+
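+ # Keep the backend (Rust) pre-tokenizer in sync with the requested `add_prefix_space`:
+ # if its serialized state disagrees, rebuild the pre-tokenizer with the new value.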
+ pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
+ if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+ pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
+ pre_tok_state["add_prefix_space"] = add_prefix_space
+ self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
+
+ self.add_prefix_space = add_prefix_space
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
+ """This corresponds to DialoGPT variants of models."""
+ input_ids = []
+ for is_user, text in conversation.iter_texts():
+ input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
+
+ if len(input_ids) > self.model_max_length:
+ input_ids = input_ids[-self.model_max_length :]
+ return input_ids
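+
+# Illustrative sketch (not part of the original PR): building conversation input ids with the
+# helper above. The checkpoint name is an assumption.
+#
+#     from transformers.pipelines.conversational import Conversation
+#     tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
+#     conversation = Conversation("Hello, how are you?")
+#     ids = tokenizer._build_conversation_input_ids(conversation)  # each turn followed by eos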
diff --git a/src/transformers/models/gptj/__init__.py b/src/transformers/models/gptj/__init__.py
index a6b144ab8251..d4c4e01a6ede 100644
--- a/src/transformers/models/gptj/__init__.py
+++ b/src/transformers/models/gptj/__init__.py
@@ -17,14 +17,23 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+)
-_import_structure = {
- "configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig", "GPTJOnnxConfig"],
-}
+_import_structure = {"configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig", "GPTJOnnxConfig"]}
-if is_torch_available():
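+# If torch (and likewise TF/Flax below) is not installed, OptionalDependencyNotAvailable is
+# raised and the corresponding modeling symbols are simply left out of the lazy import map.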
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_gptj"] = [
"GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTJForCausalLM",
@@ -34,7 +43,12 @@
"GPTJPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_gptj"] = [
"TFGPTJForCausalLM",
"TFGPTJForQuestionAnswering",
@@ -43,7 +57,12 @@
"TFGPTJPreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_gptj"] = [
"FlaxGPTJForCausalLM",
"FlaxGPTJModel",
@@ -54,7 +73,12 @@
if TYPE_CHECKING:
from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig, GPTJOnnxConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_gptj import (
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTJForCausalLM,
@@ -64,7 +88,12 @@
GPTJPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_gptj import (
TFGPTJForCausalLM,
TFGPTJForQuestionAnswering,
@@ -73,7 +102,12 @@
TFGPTJPreTrainedModel,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel
else:
diff --git a/src/transformers/models/gptj/configuration_gptj.py b/src/transformers/models/gptj/configuration_gptj.py
index 1fb6edd3db8e..c1f20a77134b 100644
--- a/src/transformers/models/gptj/configuration_gptj.py
+++ b/src/transformers/models/gptj/configuration_gptj.py
@@ -211,8 +211,9 @@ def generate_dummy_inputs(
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
if self.use_past:
+ mask_dtype = ordered_inputs["attention_mask"].dtype
ordered_inputs["attention_mask"] = torch.cat(
- [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+ [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
)
return ordered_inputs
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index d10c266d3f0e..cb05902ee422 100755
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -69,7 +69,7 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
- x = torch.stack((-x2, x1), axis=-1)
+ x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
@@ -111,7 +111,8 @@ def __init__(self, config):
self.head_dim = self.embed_dim // self.num_attention_heads
if self.head_dim * self.num_attention_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})."
+ f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
+ f" `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
@@ -162,14 +163,19 @@ def _attn(
# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
key = key.to(torch.float32)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
- attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
+ # Needs to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
attn_weights = attn_weights / self.scale_attn
@@ -333,6 +339,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
+ _no_split_modules = ["GPTJBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -603,7 +610,7 @@ def forward(
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
- attention_mask = (1.0 - attention_mask) * -10000.0
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@@ -888,7 +895,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
GPTJ_START_DOCSTRING,
)
class GPTJForSequenceClassification(GPTJPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -967,10 +974,10 @@ def forward(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
- pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
@@ -1015,7 +1022,7 @@ def forward(
GPTJ_START_DOCSTRING,
)
class GPTJForQuestionAnswering(GPTJPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py
index feaad22eff04..a1071408fb0d 100644
--- a/src/transformers/models/gptj/modeling_tf_gptj.py
+++ b/src/transformers/models/gptj/modeling_tf_gptj.py
@@ -60,14 +60,12 @@
]
-def fixed_pos_embedding(x: tf.Tensor, seq_dim: int = 1, seq_len: Optional[int] = None) -> Tuple[tf.Tensor, tf.Tensor]:
- dim = shape_list(x)[-1]
- if seq_len is None:
- seq_len = shape_list(x)[seq_dim]
+def create_sinusoidal_positions(num_pos: int, dim: int) -> tf.Tensor:
inv_freq = tf.cast(1.0 / (10000 ** (tf.range(0, dim, 2) / dim)), tf.float32)
- seq_len_range = tf.cast(tf.range(seq_len), tf.float32)
- sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", seq_len_range, inv_freq), tf.float32)
- return tf.cast(tf.sin(sinusoid_inp), dtype=x.dtype), tf.cast(tf.cos(sinusoid_inp), dtype=x.dtype)
+ sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", tf.range(num_pos, dtype=tf.float32), inv_freq), tf.float32)
+ sin, cos = tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)
+ out = tf.concat((sin, cos), axis=1)
+ return out
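+
+# Illustrative note (not part of the original change): the returned table has shape
+# [num_pos, dim], with the first dim // 2 columns holding sin values and the last dim // 2
+# holding cos values; `tf.split(sincos, 2, axis=-1)` in the attention layer recovers the
+# (sin, cos) pair per position.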
def rotate_every_two(x: tf.Tensor) -> tf.Tensor:
@@ -77,11 +75,11 @@ def rotate_every_two(x: tf.Tensor) -> tf.Tensor:
return rotate_half_tensor
-def apply_rotary_pos_emb(x: tf.Tensor, sincos: tf.Tensor, offset: int = 0) -> tf.Tensor:
+def apply_rotary_pos_emb(tensor: tf.Tensor, sincos: tf.Tensor) -> tf.Tensor:
sin_pos, cos_pos = sincos
- sin_pos = tf.repeat(sin_pos[None, offset : shape_list(x)[1] + offset, None, :], 2, 3)
- cos_pos = tf.repeat(cos_pos[None, offset : shape_list(x)[1] + offset, None, :], 2, 3)
- return (x * cos_pos) + (rotate_every_two(x) * sin_pos)
+ sin_pos = tf.repeat(sin_pos[:, :, None, :], 2, 3)
+ cos_pos = tf.repeat(cos_pos[:, :, None, :], 2, 3)
+ return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos)
class TFGPTJAttention(tf.keras.layers.Layer):
@@ -93,7 +91,8 @@ def __init__(self, config: GPTJConfig, **kwargs):
self.head_dim = self.embed_dim // self.num_attention_heads
if self.head_dim * self.num_attention_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})."
+ f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
+ f" `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = self.head_dim**0.5
self.rotary_dim = config.rotary_dim
@@ -131,6 +130,8 @@ def __init__(self, config: GPTJConfig, **kwargs):
tf.cast(tf.experimental.numpy.tril(tf.ones((self.max_positions, self.max_positions))), tf.int8),
(1, 1, self.max_positions, self.max_positions),
)
+ pos_embd_dim = self.rotary_dim or self.embed_dim
+ self.embed_positions = create_sinusoidal_positions(self.max_positions, pos_embd_dim)
def get_causal_mask(self, key_length, query_length) -> tf.Tensor:
return tf.cast(self.lower_triangle_mask[:, :, key_length - query_length : key_length, :key_length], tf.bool)
@@ -206,8 +207,9 @@ def _attn(
def call(
self,
hidden_states: tf.Tensor,
- attention_mask: Optional[tf.Tensor] = None,
layer_past: Optional[Tuple[tf.Tensor, tf.Tensor]] = None,
+ attention_mask: Optional[tf.Tensor] = None,
+ position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
@@ -220,13 +222,8 @@ def call(
key = self._split_heads(key, True)
value = self._split_heads(value, False)
- seq_len = shape_list(key)[1]
- offset = 0
-
- if layer_past is not None:
- offset = shape_list(layer_past[0])[-2]
- seq_len += offset
-
+ sincos = tf.cast(tf.gather(self.embed_positions, position_ids, axis=0), hidden_states.dtype)
+ sincos = tf.split(sincos, 2, axis=-1)
if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim :]
@@ -234,16 +231,14 @@ def call(
q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim :]
- sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
- k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
- q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
+ k_rot = apply_rotary_pos_emb(k_rot, sincos)
+ q_rot = apply_rotary_pos_emb(q_rot, sincos)
key = tf.concat((k_rot, k_pass), axis=-1)
query = tf.concat((q_rot, q_pass), axis=-1)
else:
- sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
- key = apply_rotary_pos_emb(key, sincos, offset=offset)
- query = apply_rotary_pos_emb(query, sincos, offset=offset)
+ key = apply_rotary_pos_emb(key, sincos)
+ query = apply_rotary_pos_emb(query, sincos)
key = tf.transpose(key, (0, 2, 1, 3))
query = tf.transpose(query, (0, 2, 1, 3))
@@ -309,6 +304,7 @@ def call(
hidden_states: tf.Tensor,
layer_past: Optional[tf.Tensor] = None,
attention_mask: Optional[tf.Tensor] = None,
+ position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
@@ -316,9 +312,10 @@ def call(
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
- hidden_states,
+ hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
+ position_ids=position_ids,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
@@ -465,12 +462,13 @@ def call(
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block(
- hidden_states,
- layer_past,
- attention_mask,
- head_mask[i],
- use_cache,
- output_attentions,
+ hidden_states=hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
training=training,
)
@@ -728,25 +726,21 @@ def get_output_embeddings(self):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
- def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs):
- # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
- # tests will need to be fixed after the change
-
+ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
+ token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
+ if token_type_ids is not None:
+ token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
+
+ position_ids = kwargs.get("position_ids", None)
+ attention_mask = kwargs.get("attention_mask", None)
- # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
- # for a future PR to not change too many things for now.
- # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
- position_ids = None
- attention_mask = None
- if use_xla:
- attention_mask = kwargs.get("attention_mask", None)
- if past is not None and attention_mask is not None:
- position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
- elif attention_mask is not None:
- position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
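+ # create position_ids on the fly when only an attention_mask is passed (needed for batch generation)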
+ if attention_mask is not None and position_ids is None:
+ position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
+ if past:
+ position_ids = tf.expand_dims(position_ids[:, -1], -1)
return {
"input_ids": inputs,
@@ -754,6 +748,7 @@ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_x
"position_ids": position_ids,
"past": past,
"use_cache": use_cache,
+ "token_type_ids": token_type_ids,
}
@unpack_inputs
@@ -929,7 +924,7 @@ def call(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None
diff --git a/src/transformers/models/groupvit/__init__.py b/src/transformers/models/groupvit/__init__.py
new file mode 100644
index 000000000000..8d902054975b
--- /dev/null
+++ b/src/transformers/models/groupvit/__init__.py
@@ -0,0 +1,71 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {
+ "configuration_groupvit": [
+ "GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "GroupViTConfig",
+ "GroupViTTextConfig",
+ "GroupViTVisionConfig",
+ ],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_groupvit"] = [
+ "GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "GroupViTModel",
+ "GroupViTPreTrainedModel",
+ "GroupViTTextModel",
+ "GroupViTVisionModel",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_groupvit import (
+ GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ GroupViTConfig,
+ GroupViTTextConfig,
+ GroupViTVisionConfig,
+ )
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_groupvit import (
+ GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ GroupViTModel,
+ GroupViTPreTrainedModel,
+ GroupViTTextModel,
+ GroupViTVisionModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/groupvit/configuration_groupvit.py b/src/transformers/models/groupvit/configuration_groupvit.py
new file mode 100644
index 000000000000..8940cf40b9f1
--- /dev/null
+++ b/src/transformers/models/groupvit/configuration_groupvit.py
@@ -0,0 +1,345 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" GroupViT model configuration"""
+
+import copy
+import os
+from typing import Union
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "nvidia/groupvit-gcc-yfcc": "https://huggingface.co/nvidia/groupvit-gcc-yfcc/resolve/main/config.json",
+}
+
+
+class GroupViTTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GroupViTTextModel`]. It is used to instantiate a
+ GroupViT model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the GroupViT
+ [nvidia/groupvit-gcc-yfcc](https://huggingface.co/nvidia/groupvit-gcc-yfcc) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 49408):
+ Vocabulary size of the GroupViT text model. Defines the number of different tokens that can be represented
+ by the `inputs_ids` passed when calling [`GroupViTModel`].
+ hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 4):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ max_position_embeddings (`int`, *optional*, defaults to 77):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1.0):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+
+ Example:
+
+ ```python
+ >>> from transformers import GroupViTTextConfig, GroupViTTextModel
+
+ >>> # Initializing a GroupViTTextModel with nvidia/groupvit-gcc-yfcc style configuration
+ >>> configuration = GroupViTTextConfig()
+
+ >>> model = GroupViTTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "groupvit_text_model"
+
+ def __init__(
+ self,
+ vocab_size=49408,
+ hidden_size=256,
+ intermediate_size=1024,
+ num_hidden_layers=12,
+ num_attention_heads=4,
+ max_position_embeddings=77,
+ hidden_act="quick_gelu",
+ layer_norm_eps=0.00001,
+ dropout=0.0,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ **kwargs
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.dropout = dropout
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.attention_dropout = attention_dropout
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the text config dict if we are loading from GroupViTConfig
+ if config_dict.get("model_type") == "groupvit":
+ config_dict = config_dict["text_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class GroupViTVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GroupViTVisionModel`]. It is used to instantiate
+ a GroupViT model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the GroupViT
+ [nvidia/groupvit-gcc-yfcc](https://huggingface.co/nvidia/groupvit-gcc-yfcc) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 384):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 1536):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ depths (`List[int]`, *optional*, defaults to [6, 3, 3]):
+ The number of layers in each encoder block.
+ num_group_tokens (`List[int]`, *optional*, defaults to [64, 8, 0]):
+ The number of group tokens for each stage.
+ num_output_groups (`List[int]`, *optional*, defaults to [64, 8, 8]):
+ The number of output groups for each stage, 0 means no group.
+ num_attention_heads (`int`, *optional*, defaults to 6):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1.0):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+
+ Example:
+
+ ```python
+ >>> from transformers import GroupViTVisionConfig, GroupViTVisionModel
+
+ >>> # Initializing a GroupViTVisionModel with nvidia/groupvit-gcc-yfcc style configuration
+ >>> configuration = GroupViTVisionConfig()
+
+ >>> model = GroupViTVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "groupvit_vision_model"
+
+ def __init__(
+ self,
+ hidden_size=384,
+ intermediate_size=1536,
+ depths=[6, 3, 3],
+ num_hidden_layers=12,
+ num_group_tokens=[64, 8, 0],
+ num_output_groups=[64, 8, 8],
+ num_attention_heads=6,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ hidden_act="gelu",
+ layer_norm_eps=1e-5,
+ dropout=0.0,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ assign_eps=1.0,
+ assign_mlp_ratio=[0.5, 4],
+ qkv_bias=True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.depths = depths
+ if num_hidden_layers != sum(depths):
+ logger.warning(
+ f"Manually setting num_hidden_layers to {num_hidden_layers}, but we expect num_hidden_layers ="
+ f" sum(depth) = {sum(depths)}"
+ )
+ self.num_hidden_layers = num_hidden_layers
+ self.num_group_tokens = num_group_tokens
+ self.num_output_groups = num_output_groups
+ self.num_attention_heads = num_attention_heads
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.assign_eps = assign_eps
+ self.assign_mlp_ratio = assign_mlp_ratio
+ self.qkv_bias = qkv_bias
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the vision config dict if we are loading from GroupViTConfig
+ if config_dict.get("model_type") == "groupvit":
+ config_dict = config_dict["vision_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class GroupViTConfig(PretrainedConfig):
+ r"""
+ [`GroupViTConfig`] is the configuration class to store the configuration of a [`GroupViTModel`]. It is used to
+ instantiate a GroupViT model according to the specified arguments, defining the text model and vision model
+ configs.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config_dict (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`GroupViTTextConfig`].
+ vision_config_dict (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`GroupViTVisionConfig`].
+ projection_dim (`int`, *optional*, defaults to 256):
+ Dimensionality of text and vision projection layers.
+ projection_intermediate_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of intermediate layer of text and vision projection layers.
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
+ The initial value of the *logit_scale* parameter. Default is used as per the original GroupViT
+ implementation.
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
+ """
+
+ model_type = "groupvit"
+ is_composition = True
+
+ def __init__(
+ self,
+ text_config_dict=None,
+ vision_config_dict=None,
+ projection_dim=256,
+ projection_intermediate_dim=4096,
+ logit_scale_init_value=2.6592,
+ **kwargs
+ ):
+ super().__init__(text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs)
+
+ if text_config_dict is None:
+ text_config_dict = {}
+ logger.info("text_config_dict is None. Initializing the GroupViTTextConfig with default values.")
+
+ if vision_config_dict is None:
+ vision_config_dict = {}
+ logger.info("vision_config_dict is None. initializing the GroupViTVisionConfig with default values.")
+
+ self.text_config = GroupViTTextConfig(**text_config_dict)
+ self.vision_config = GroupViTVisionConfig(**vision_config_dict)
+
+ self.projection_dim = projection_dim
+ self.projection_intermediate_dim = projection_intermediate_dim
+ self.logit_scale_init_value = logit_scale_init_value
+ self.initializer_range = 0.02
+ self.initializer_factor = 1.0
+ self.output_segmentation = False
+
+ @classmethod
+ def from_text_vision_configs(cls, text_config: GroupViTTextConfig, vision_config: GroupViTVisionConfig, **kwargs):
+ r"""
+ Instantiate a [`GroupViTConfig`] (or a derived class) from groupvit text model configuration and groupvit
+ vision model configuration.
+
+ Returns:
+ [`GroupViTConfig`]: An instance of a configuration object
+ """
+
+ return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs)
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+ Returns:
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["text_config"] = self.text_config.to_dict()
+ output["vision_config"] = self.vision_config.to_dict()
+ output["model_type"] = self.__class__.model_type
+ return output
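+
+# Illustrative sketch (not part of the original PR): composing a GroupViTConfig from the two
+# sub-configurations defined above.
+#
+#     text_config = GroupViTTextConfig()
+#     vision_config = GroupViTVisionConfig()
+#     config = GroupViTConfig.from_text_vision_configs(text_config, vision_config)
+#     config.to_dict()["text_config"]["hidden_size"]  # 256, serialized from the sub-config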
diff --git a/src/transformers/models/groupvit/convert_groupvit_nvlab_to_hf.py b/src/transformers/models/groupvit/convert_groupvit_nvlab_to_hf.py
new file mode 100644
index 000000000000..e83bdd35cb37
--- /dev/null
+++ b/src/transformers/models/groupvit/convert_groupvit_nvlab_to_hf.py
@@ -0,0 +1,217 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Convert GroupViT checkpoints from the original repository.
+
+URL: https://github.com/NVlabs/GroupViT
+"""
+
+import argparse
+
+import torch
+from PIL import Image
+
+import requests
+from transformers import CLIPProcessor, GroupViTConfig, GroupViTModel
+
+
+def rename_key(name):
+ # vision encoder
+ if "img_encoder.pos_embed" in name:
+ name = name.replace("img_encoder.pos_embed", "vision_model.embeddings.position_embeddings")
+ if "img_encoder.patch_embed.proj" in name:
+ name = name.replace("img_encoder.patch_embed.proj", "vision_model.embeddings.patch_embeddings.projection")
+ if "img_encoder.patch_embed.norm" in name:
+ name = name.replace("img_encoder.patch_embed.norm", "vision_model.embeddings.layernorm")
+ if "img_encoder.layers" in name:
+ name = name.replace("img_encoder.layers", "vision_model.encoder.stages")
+ if "blocks" in name and "res" not in name:
+ name = name.replace("blocks", "layers")
+ if "attn" in name and "pre_assign" not in name:
+ name = name.replace("attn", "self_attn")
+ if "proj" in name and "self_attn" in name and "text" not in name:
+ name = name.replace("proj", "out_proj")
+ if "pre_assign_attn.attn.proj" in name:
+ name = name.replace("pre_assign_attn.attn.proj", "pre_assign_attn.attn.out_proj")
+ if "norm1" in name:
+ name = name.replace("norm1", "layer_norm1")
+ if "norm2" in name and "pre_assign" not in name:
+ name = name.replace("norm2", "layer_norm2")
+ if "img_encoder.norm" in name:
+ name = name.replace("img_encoder.norm", "vision_model.layernorm")
+ # text encoder
+ if "text_encoder.token_embedding" in name:
+ name = name.replace("text_encoder.token_embedding", "text_model.embeddings.token_embedding")
+ if "text_encoder.positional_embedding" in name:
+ name = name.replace("text_encoder.positional_embedding", "text_model.embeddings.position_embedding.weight")
+ if "text_encoder.transformer.resblocks." in name:
+ name = name.replace("text_encoder.transformer.resblocks.", "text_model.encoder.layers.")
+ if "ln_1" in name:
+ name = name.replace("ln_1", "layer_norm1")
+ if "ln_2" in name:
+ name = name.replace("ln_2", "layer_norm2")
+ if "c_fc" in name:
+ name = name.replace("c_fc", "fc1")
+ if "c_proj" in name:
+ name = name.replace("c_proj", "fc2")
+ if "text_encoder" in name:
+ name = name.replace("text_encoder", "text_model")
+ if "ln_final" in name:
+ name = name.replace("ln_final", "final_layer_norm")
+ # projection layers
+ if "img_projector.linear_hidden." in name:
+ name = name.replace("img_projector.linear_hidden.", "visual_projection.")
+ if "img_projector.linear_out." in name:
+ name = name.replace("img_projector.linear_out.", "visual_projection.3.")
+ if "text_projector.linear_hidden" in name:
+ name = name.replace("text_projector.linear_hidden", "text_projection")
+ if "text_projector.linear_out" in name:
+ name = name.replace("text_projector.linear_out", "text_projection.3")
+
+ return name
+
+
+def convert_state_dict(orig_state_dict, config):
+ for key in orig_state_dict.copy().keys():
+ val = orig_state_dict.pop(key)
+
+ if "qkv" in key:
+ # weights and biases of the key, value and query projections of vision encoder's attention layers require special treatment:
+ # we need to split them up into separate matrices/vectors
+ key_split = key.split(".")
+ stage_num, layer_num = int(key_split[2]), int(key_split[4])
+ dim = config.vision_config.hidden_size
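+            # the original checkpoint stacks q, k and v along dim 0 of a single qkv projection,
+            # so the first `dim` rows (entries for biases) are q, the next `dim` are k, the last `dim` are v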
+ if "weight" in key:
+ orig_state_dict[
+ f"vision_model.encoder.stages.{stage_num}.layers.{layer_num}.self_attn.q_proj.weight"
+ ] = val[:dim, :]
+ orig_state_dict[
+ f"vision_model.encoder.stages.{stage_num}.layers.{layer_num}.self_attn.k_proj.weight"
+ ] = val[dim : dim * 2, :]
+ orig_state_dict[
+ f"vision_model.encoder.stages.{stage_num}.layers.{layer_num}.self_attn.v_proj.weight"
+ ] = val[-dim:, :]
+ else:
+ orig_state_dict[
+ f"vision_model.encoder.stages.{stage_num}.layers.{layer_num}.self_attn.q_proj.bias"
+ ] = val[:dim]
+ orig_state_dict[
+ f"vision_model.encoder.stages.{stage_num}.layers.{layer_num}.self_attn.k_proj.bias"
+ ] = val[dim : dim * 2]
+ orig_state_dict[
+ f"vision_model.encoder.stages.{stage_num}.layers.{layer_num}.self_attn.v_proj.bias"
+ ] = val[-dim:]
+ elif "in_proj" in key:
+ # weights and biases of the key, value and query projections of text encoder's attention layers require special treatment:
+ # we need to split them up into separate matrices/vectors
+ key_split = key.split(".")
+ layer_num = int(key_split[3])
+ dim = config.text_config.hidden_size
+ if "weight" in key:
+ orig_state_dict[f"text_model.encoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
+ orig_state_dict[f"text_model.encoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[
+ dim : dim * 2, :
+ ]
+ orig_state_dict[f"text_model.encoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
+ else:
+ orig_state_dict[f"text_model.encoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
+ orig_state_dict[f"text_model.encoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
+ orig_state_dict[f"text_model.encoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
+ else:
+ new_name = rename_key(key)
+ # squeeze if necessary
+ if (
+ "text_projection.0" in new_name
+ or "text_projection.3" in new_name
+ or "visual_projection.0" in new_name
+ or "visual_projection.3" in new_name
+ ):
+ orig_state_dict[new_name] = val.squeeze_()
+ else:
+ orig_state_dict[new_name] = val
+
+ return orig_state_dict
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ im = Image.open(requests.get(url, stream=True).raw)
+ return im
+
+
+@torch.no_grad()
+def convert_groupvit_checkpoint(
+ checkpoint_path, pytorch_dump_folder_path, model_name="groupvit-gcc-yfcc", push_to_hub=False
+):
+ """
+ Copy/paste/tweak model's weights to the Transformers design.
+ """
+ config = GroupViTConfig()
+ model = GroupViTModel(config).eval()
+
+ state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
+ new_state_dict = convert_state_dict(state_dict, config)
+ missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
+ assert missing_keys == ["text_model.embeddings.position_ids"]
+ assert (unexpected_keys == ["multi_label_logit_scale"]) or (len(unexpected_keys) == 0)
+
+ # verify result
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ image = prepare_img()
+ inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, padding=True, return_tensors="pt")
+
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ if model_name == "groupvit-gcc-yfcc":
+ expected_logits = torch.tensor([[13.3523, 6.3629]])
+ elif model_name == "groupvit-gcc-redcaps":
+ expected_logits = torch.tensor([[16.1873, 8.6230]])
+ else:
+ raise ValueError(f"Model name {model_name} not supported.")
+ assert torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3)
+
+ processor.save_pretrained(pytorch_dump_folder_path)
+ model.save_pretrained(pytorch_dump_folder_path)
+ print("Successfully saved processor and model to", pytorch_dump_folder_path)
+
+ if push_to_hub:
+ print("Pushing to the hub...")
+ processor.push_to_hub(model_name, organization="nielsr")
+ model.push_to_hub(model_name, organization="nielsr")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to dump the processor and PyTorch model."
+ )
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to GroupViT checkpoint")
+ parser.add_argument(
+ "--model_name",
+ default="groupvit-gccy-fcc",
+ type=str,
+ help="Name of the model. Expecting either 'groupvit-gcc-yfcc' or 'groupvit-gcc-redcaps'",
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ help="Whether or not to push the converted model and processor to the š¤ hub using the provided `model_name`.",
+ )
+ args = parser.parse_args()
+
+ convert_groupvit_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.model_name, args.push_to_hub)
diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py
new file mode 100644
index 000000000000..1073d4bfea87
--- /dev/null
+++ b/src/transformers/models/groupvit/modeling_groupvit.py
@@ -0,0 +1,1604 @@
+# coding=utf-8
+# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch GroupViT model."""
+
+
+import collections.abc
+import math
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "nvidia/groupvit-gcc-yfcc"
+
+GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "nvidia/groupvit-gcc-yfcc",
+ # See all GroupViT models at https://huggingface.co/models?filter=groupvit
+]
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+# contrastive loss function, adapted from
+# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/GroupViT.html
+def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
+
+
+# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->groupvit
+def groupvit_loss(similarity: torch.Tensor) -> torch.Tensor:
+ caption_loss = contrastive_loss(similarity)
+ image_loss = contrastive_loss(similarity.T)
+ return (caption_loss + image_loss) / 2.0
+
+
+def hard_softmax(logits: torch.Tensor, dim: int):
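+    # Straight-through trick: the forward pass uses the hard one-hot argmax, while the
+    # backward pass uses the gradients of the soft softmax (y_hard - y_soft.detach() + y_soft).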
+ y_soft = logits.softmax(dim)
+ # Straight through.
+ index = y_soft.max(dim, keepdim=True)[1]
+ y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+ ret = y_hard - y_soft.detach() + y_soft
+
+ return ret
+
+
+def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
+ # more stable https://github.com/pytorch/pytorch/issues/41663
+ gumbel_dist = torch.distributions.gumbel.Gumbel(
+ torch.tensor(0.0, device=logits.device, dtype=logits.dtype),
+ torch.tensor(1.0, device=logits.device, dtype=logits.dtype),
+ )
+ gumbels = gumbel_dist.sample(logits.shape)
+
+ gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
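+    # smaller tau sharpens the relaxed distribution towards discrete one-hot samples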
+ y_soft = gumbels.softmax(dim)
+
+ if hard:
+ # Straight through.
+ index = y_soft.max(dim, keepdim=True)[1]
+ y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+ ret = y_hard - y_soft.detach() + y_soft
+ else:
+ # Reparametrization trick.
+ ret = y_soft
+ return ret
+
+
+def resize_attention_map(attentions, height, width, align_corners=False):
+ """
+ Args:
+ attentions (`torch.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width]
+ height (`int`): height of the output attention map
+ width (`int`): width of the output attention map
+ align_corners (`bool`, *optional*): the `align_corners` argument for `nn.functional.interpolate`.
+
+ Returns:
+ `torch.Tensor`: resized attention map of shape [batch_size, groups, height, width]
+ """
+
+ scale = (height * width // attentions.shape[2]) ** 0.5
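+    # the feature grid is assumed to share the aspect ratio of the target (height, width);
+    # recover the shorter side first, then derive the other side from the number of tokens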
+ if height > width:
+ feat_width = int(np.round(width / scale))
+ feat_height = attentions.shape[2] // feat_width
+ else:
+ feat_height = int(np.round(height / scale))
+ feat_width = attentions.shape[2] // feat_height
+
+ batch_size = attentions.shape[0]
+ groups = attentions.shape[1] # number of group tokens
+ # [batch_size, groups, feat_height*feat_width] -> [batch_size, groups, feat_height, feat_width]
+ attentions = attentions.reshape(batch_size, groups, feat_height, feat_width)
+ attentions = nn.functional.interpolate(
+ attentions, size=(height, width), mode="bilinear", align_corners=align_corners
+ )
+ return attentions
+
+
+def get_grouping_from_attentions(attentions, hw_shape):
+ """
+ Args:
+ attentions (`tuple(torch.FloatTensor)`): tuple of attention maps returned by `GroupViTVisionTransformer`
+ hw_shape (`tuple(int)`): height and width of the output attention map
+ Returns:
+ `torch.Tensor`: the attention map of shape [batch_size, groups, height, width]
+ """
+
+ attn_maps = []
+ with torch.no_grad():
+ prev_attn_masks = None
+ for attn_masks in attentions:
+ # [batch_size, num_groups, height x width] -> [batch_size, height x width, num_groups]
+ attn_masks = attn_masks.permute(0, 2, 1).contiguous()
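+            # grouping assignments of consecutive stages are composed by matrix multiplication, so the
+            # accumulated map assigns every input location to one of the final, coarser groups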
+ if prev_attn_masks is None:
+ prev_attn_masks = attn_masks
+ else:
+ prev_attn_masks = prev_attn_masks @ attn_masks
+ # [batch_size, heightxwidth, num_groups] -> [batch_size, num_groups, heightxwidth] -> [batch_size, num_groups, height, width]
+ cur_attn_map = resize_attention_map(prev_attn_masks.permute(0, 2, 1).contiguous(), *hw_shape)
+ attn_maps.append(cur_attn_map)
+
+ # [batch_size, num_groups, height, width]
+ final_grouping = attn_maps[-1]
+
+ return final_grouping
+
+
+class GroupViTCrossAttentionLayer(nn.Module):
+ def __init__(self, config: GroupViTVisionConfig):
+ super().__init__()
+ self.attn = GroupViTAttention(config)
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp = GroupViTMLP(config)
+ self.norm_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, query, key):
+ x = query
+ x = x + self.attn(query, encoder_hidden_states=key)[0]
+ x = x + self.mlp(self.norm2(x))
+ x = self.norm_post(x)
+ return x
+
+
+class GroupViTAssignAttention(nn.Module):
+ def __init__(self, config: GroupViTVisionConfig):
+ super().__init__()
+ self.scale = config.hidden_size**-0.5
+
+ self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
+ self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
+ self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size)
+ self.assign_eps = config.assign_eps
+
+ def get_attn(self, attn, gumbel=True, hard=True):
+
+ if gumbel and self.training:
+ attn = gumbel_softmax(attn, dim=-2, hard=hard)
+ else:
+ if hard:
+ attn = hard_softmax(attn, dim=-2)
+ else:
+ attn = nn.functional.softmax(attn, dim=-2)
+
+ return attn
+
+ def forward(self, query, key):
+ value = key
+ # [batch_size, query_length, channels]
+ query = self.q_proj(query)
+
+ # [batch_size, key_length, channels]
+ key = self.k_proj(key)
+
+ # [batch_size, key_length, channels]
+ value = self.v_proj(value)
+
+ # [batch_size, query_length, key_length]
+ raw_attn = (query @ key.transpose(-2, -1)) * self.scale
+
+ attn = self.get_attn(raw_attn)
+ soft_attn = self.get_attn(raw_attn, gumbel=False, hard=False)
+
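+        # the (hard) softmax over the group axis assigns each image token to a group; normalizing by the
+        # per-group sum then averages the features of the tokens assigned to each group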
+ attn = attn / (attn.sum(dim=-1, keepdim=True) + self.assign_eps)
+
+ out = attn @ value
+
+ out = self.proj(out)
+
+ return out, soft_attn
+
+
+class GroupViTTokenAssign(nn.Module):
+ def __init__(self, config: GroupViTVisionConfig, num_group_token, num_output_group):
+ super().__init__()
+ self.num_output_group = num_output_group
+ # norm on group_tokens
+ self.norm_tokens = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ assign_mlp_ratio = (
+ config.assign_mlp_ratio
+ if isinstance(config.assign_mlp_ratio, collections.abc.Iterable)
+ else (config.assign_mlp_ratio, config.assign_mlp_ratio)
+ )
+ tokens_dim, channels_dim = [int(x * config.hidden_size) for x in assign_mlp_ratio]
+ self.mlp_inter = GroupViTMixerMLP(config, num_group_token, tokens_dim, num_output_group)
+ self.norm_post_tokens = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ # norm on x
+ self.norm_x = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pre_assign_attn = GroupViTCrossAttentionLayer(config)
+
+ self.assign = GroupViTAssignAttention(config)
+ self.norm_new_x = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp_channels = GroupViTMLP(config, config.hidden_size, channels_dim, config.hidden_size)
+
+ def project_group_token(self, group_tokens):
+ """
+ Args:
+ group_tokens (torch.Tensor): group tokens, [batch_size, num_group_tokens, channels]
+
+ Returns:
+ projected_group_tokens (torch.Tensor): [batch_size, num_output_groups, channels]
+ """
+ # [B, num_output_groups, C] <- [B, num_group_tokens, C]
+ projected_group_tokens = self.mlp_inter(group_tokens)
+ projected_group_tokens = self.norm_post_tokens(projected_group_tokens)
+ return projected_group_tokens
+
+ def forward(self, image_tokens, group_tokens):
+ """
+ Args:
+ image_tokens (`torch.Tensor`): image tokens, of shape [batch_size, input_length, channels]
+ group_tokens (`torch.Tensor`): group tokens, [batch_size, num_group_tokens, channels]
+ """
+
+ group_tokens = self.norm_tokens(group_tokens)
+ image_tokens = self.norm_x(image_tokens)
+ # [batch_size, num_output_groups, channels]
+ projected_group_tokens = self.project_group_token(group_tokens)
+ projected_group_tokens = self.pre_assign_attn(projected_group_tokens, image_tokens)
+ new_image_tokens, attention = self.assign(projected_group_tokens, image_tokens)
+ new_image_tokens += projected_group_tokens
+
+ new_image_tokens = new_image_tokens + self.mlp_channels(self.norm_new_x(new_image_tokens))
+
+ return new_image_tokens, attention
+
+
+@dataclass
+class GroupViTModelOutput(ModelOutput):
+ """
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+ Contrastive loss for image-text similarity.
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
+ similarity scores.
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
+ similarity scores.
+ segmentation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
+ Classification scores for each pixel.
+
+ <Tip warning={true}>
+
+ The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
+ to avoid doing two interpolations and losing some quality when a user needs to resize the logits to the
+ original image size as post-processing. You should always check your logits shape and resize as needed.
+
+ </Tip>
+
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+ The text embeddings obtained by applying the projection layer to the pooled output of
+ [`GroupViTTextModel`].
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+ The image embeddings obtained by applying the projection layer to the pooled output of
+ [`GroupViTVisionModel`].
+ text_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`GroupViTTextModel`].
+ vision_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`GroupViTVisionModel`].
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits_per_image: torch.FloatTensor = None
+ logits_per_text: torch.FloatTensor = None
+ segmentation_logits: torch.FloatTensor = None
+ text_embeds: torch.FloatTensor = None
+ image_embeds: torch.FloatTensor = None
+ text_model_output: BaseModelOutputWithPooling = None
+ vision_model_output: BaseModelOutputWithPooling = None
+
+ def to_tuple(self) -> Tuple[Any]:
+ return tuple(
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
+ for k in self.keys()
+ )
+
+
+class GroupViTPatchEmbeddings(nn.Module):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self,
+ image_size: int = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ num_channels: int = 3,
+ embed_dim: int = 768,
+ ):
+ super().__init__()
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if not interpolate_pos_encoding:
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
+ x = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return x
+
+
+class GroupViTVisionEmbeddings(nn.Module):
+ def __init__(self, config: GroupViTVisionConfig):
+ super().__init__()
+
+ self.patch_embeddings = GroupViTPatchEmbeddings(
+ image_size=config.image_size,
+ patch_size=config.patch_size,
+ num_channels=config.num_channels,
+ embed_dim=config.hidden_size,
+ )
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches, config.hidden_size))
+ self.dropout = nn.Dropout(config.dropout)
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.config = config
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows interpolating the pre-trained position encodings, so that the model can be used on
+ higher resolution images.
+
+ Source:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ """
+
+ npatch = embeddings.shape[1]
+ if npatch == self.position_embeddings.shape[1] and height == width:
+ return self.position_embeddings
+ patch_pos_embed = self.position_embeddings
+ num_original_pos_embed = patch_pos_embed.shape[1]
+ dim = embeddings.shape[-1]
+ feat_height = height // self.config.patch_size
+ feat_width = width // self.config.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ feat_height, feat_width = feat_height + 0.1, feat_width + 0.1
+ original_height = original_width = math.sqrt(num_original_pos_embed)
+ reshaped_patch_pos_embed = patch_pos_embed.reshape(1, int(original_height), int(original_width), dim).permute(
+ 0, 3, 1, 2
+ )
+ scale_factor = (feat_height / original_height, feat_width / original_width)
+ patch_pos_embed = nn.functional.interpolate(
+ reshaped_patch_pos_embed,
+ scale_factor=scale_factor,
+ mode="bicubic",
+ align_corners=False,
+ )
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return patch_pos_embed
+
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+ embeddings = self.layernorm(embeddings)
+
+ batch_size, seq_len, _ = embeddings.size()
+
+ # add positional encoding to each token
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->GroupViT
+class GroupViTTextEmbeddings(nn.Module):
+ def __init__(self, config: GroupViTTextConfig):
+ super().__init__()
+ embed_dim = config.hidden_size
+
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+
+ return embeddings
+
+
+class GroupViTStage(nn.Module):
+ """This corresponds to the `GroupingLayer` class in the GroupViT implementation."""
+
+ def __init__(
+ self,
+ config: GroupViTVisionConfig,
+ depth: int,
+ num_prev_group_token: int,
+ num_group_token: int,
+ num_output_group: int,
+ ):
+ super().__init__()
+ self.depth = depth
+ self.num_group_token = num_group_token
+ if num_group_token > 0:
+ self.group_token = nn.Parameter(torch.zeros(1, num_group_token, config.hidden_size))
+ else:
+ self.group_token = None
+ self.gradient_checkpointing = False
+ self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(depth)])
+
+ if num_group_token > 0:
+ self.downsample = GroupViTTokenAssign(
+ config=config,
+ num_group_token=num_group_token,
+ num_output_group=num_output_group,
+ )
+ else:
+ self.downsample = None
+
+ if num_prev_group_token > 0 and num_group_token > 0:
+ self.group_projector = nn.Sequential(
+ nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
+ GroupViTMixerMLP(config, num_prev_group_token, config.hidden_size // 2, num_group_token),
+ )
+ else:
+ self.group_projector = None
+
+ @property
+ def with_group_token(self):
+ return self.group_token is not None
+
+ def split_x(self, x):
+ if self.with_group_token:
+ return x[:, : -self.num_group_token], x[:, -self.num_group_token :]
+ else:
+ return x, None
+
+ def concat_x(self, x: torch.Tensor, group_token: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if group_token is None:
+ return x
+ return torch.cat([x, group_token], dim=1)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ prev_group_token: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ prev_group_token (`torch.FloatTensor`, *optional*): group tokens of shape
+ `(batch, num_prev_group_token, embed_dim)` coming from the previous stage.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the grouping tensors of the grouping block.
+ """
+ if self.with_group_token:
+ group_token = self.group_token.expand(hidden_states.size(0), -1, -1)
+ if self.group_projector is not None:
+ group_token = group_token + self.group_projector(prev_group_token)
+ else:
+ group_token = None
+
+ x = hidden_states
+
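+        # append the learnable group tokens to the patch tokens so they can attend to each other
+        # through this stage's transformer layers before the patches are pooled into groups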
+ cat_x = self.concat_x(x, group_token)
+ for layer in self.layers:
+ layer_out = layer(cat_x, attention_mask=None, causal_attention_mask=None)
+ cat_x = layer_out[0]
+
+ x, group_token = self.split_x(cat_x)
+
+ attention = None
+ if self.downsample is not None:
+ x, attention = self.downsample(x, group_token)
+
+ outputs = (x, group_token)
+ if output_attentions:
+ outputs = outputs + (attention,)
+
+ return outputs
+
+
+class GroupViTMLP(nn.Module):
+ def __init__(
+ self,
+ config: GroupViTVisionConfig,
+ hidden_size: Optional[int] = None,
+ intermediate_size: Optional[int] = None,
+ output_size: Optional[int] = None,
+ ):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ hidden_size = hidden_size if hidden_size is not None else config.hidden_size
+ intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+ output_size = output_size if output_size is not None else hidden_size
+ self.fc1 = nn.Linear(hidden_size, intermediate_size)
+ self.fc2 = nn.Linear(intermediate_size, output_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class GroupViTMixerMLP(GroupViTMLP):
+ def forward(self, x):
+ x = super().forward(x.transpose(1, 2))
+ return x.transpose(1, 2)
+
+
+class GroupViTAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+ is_cross_attention = encoder_hidden_states is not None
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scale
+ if is_cross_attention:
+ key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz)
+ else:
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->GroupViT
+class GroupViTEncoderLayer(nn.Module):
+ def __init__(self, config: GroupViTConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = GroupViTAttention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim)
+ self.mlp = GroupViTMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ causal_attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ `(config.encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class GroupViTPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = GroupViTConfig
+ base_model_prefix = "groupvit"
+ supports_gradient_checkpointing = True
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+
+ init_range = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=init_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ factor = self.config.initializer_factor
+ if isinstance(module, GroupViTTextEmbeddings):
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
+ elif isinstance(module, GroupViTAttention):
+ factor = self.config.initializer_factor
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+ out_proj_std = (module.embed_dim**-0.5) * factor
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
+ elif isinstance(module, GroupViTMLP):
+ factor = self.config.initializer_factor
+ in_proj_std = (
+ (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+ )
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
+ nn.init.normal_(module.fc1.weight, std=fc_std)
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (GroupViTTextEncoder, GroupViTVisionEncoder)):
+ module.gradient_checkpointing = value
+
+
+GROUPVIT_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`GroupViTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+GROUPVIT_TEXT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+GROUPVIT_VISION_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+ [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+GROUPVIT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`CLIPFeatureExtractor`]. See
+ [`CLIPFeatureExtractor.__call__`] for details.
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class GroupViTVisionEncoder(nn.Module):
+ def __init__(self, config: GroupViTVisionConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.stages = nn.ModuleList(
+ [
+ GroupViTStage(
+ config=config,
+ depth=config.depths[i],
+ num_group_token=config.num_group_tokens[i],
+ num_output_group=config.num_output_groups[i],
+ num_prev_group_token=config.num_output_groups[i - 1] if i > 0 else 0,
+ )
+ for i in range(len(config.depths))
+ ]
+ )
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ all_hidden_states = () if output_hidden_states else None
+ all_groupings = () if output_attentions else None
+
+ group_tokens = None
+
+ for i, stage in enumerate(self.stages):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_outputs = stage(hidden_states, group_tokens, output_attentions)
+
+ hidden_states = layer_outputs[0]
+ group_tokens = layer_outputs[1]
+
+ if output_attentions and layer_outputs[2] is not None:
+ all_groupings = all_groupings + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_groupings] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_groupings
+ )
+
+
+class GroupViTTextEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
+ [`GroupViTEncoderLayer`].
+
+ Args:
+ config: GroupViTTextConfig
+ """
+
+ def __init__(self, config: GroupViTTextConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPTextTransformer with CLIPText->GroupViTText, CLIPEncoder->GroupViTTextEncoder, CLIP_TEXT->GROUPVIT_TEXT
+class GroupViTTextTransformer(nn.Module):
+ def __init__(self, config: GroupViTTextConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+ self.embeddings = GroupViTTextEmbeddings(config)
+ self.encoder = GroupViTTextEncoder(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
+
+ @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=GroupViTTextConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is None:
+ raise ValueError("You have to specify either input_ids")
+
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+ bsz, seq_len = input_shape
+ # CLIP's text model uses causal mask, prepare it here.
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+ causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
+ hidden_states.device
+ )
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+ def _build_causal_attention_mask(self, bsz, seq_len, dtype):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
+ mask.fill_(torch.tensor(torch.finfo(dtype).min))
+ mask.triu_(1) # zero out the lower diagonal
+ mask = mask.unsqueeze(1) # expand mask
+ return mask
+
+
+class GroupViTTextModel(GroupViTPreTrainedModel):
+ config_class = GroupViTTextConfig
+
+ def __init__(self, config: GroupViTTextConfig):
+ super().__init__(config)
+ self.text_model = GroupViTTextTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, value):
+ self.text_model.embeddings.token_embedding = value
+
+ @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=GroupViTTextConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import CLIPTokenizer, GroupViTTextModel
+
+ >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> model = GroupViTTextModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+class GroupViTVisionTransformer(nn.Module):
+ def __init__(self, config: GroupViTVisionConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = GroupViTVisionEmbeddings(config)
+ self.encoder = GroupViTVisionEncoder(config)
+ self.layernorm = nn.LayerNorm(embed_dim)
+
+ @add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=GroupViTVisionConfig)
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.embeddings(pixel_values)
+
+ encoder_outputs = self.encoder(
+ hidden_states=hidden_states,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+
+ # normalize the last hidden state
+ last_hidden_state = self.layernorm(last_hidden_state)
+ pooled_output = last_hidden_state.mean(dim=1)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class GroupViTVisionModel(GroupViTPreTrainedModel):
+ config_class = GroupViTVisionConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: GroupViTVisionConfig):
+ super().__init__(config)
+ self.vision_model = GroupViTVisionTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> GroupViTPatchEmbeddings:
+ return self.vision_model.embeddings.patch_embeddings
+
+ @add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=GroupViTVisionConfig)
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, GroupViTVisionModel
+
+ >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> model = GroupViTVisionModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
+ ```"""
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+@add_start_docstrings(GROUPVIT_START_DOCSTRING)
+class GroupViTModel(GroupViTPreTrainedModel):
+ config_class = GroupViTConfig
+
+ def __init__(self, config: GroupViTConfig):
+ super().__init__(config)
+
+ if not isinstance(config.text_config, GroupViTTextConfig):
+ raise ValueError(
+ "config.text_config is expected to be of type GroupViTTextConfig but is of type"
+ f" {type(config.text_config)}."
+ )
+
+ if not isinstance(config.vision_config, GroupViTVisionConfig):
+ raise ValueError(
+ "config.vision_config is expected to be of type GroupViTVisionConfig but is of type"
+ f" {type(config.vision_config)}."
+ )
+
+ text_config = config.text_config
+ vision_config = config.vision_config
+
+ self.projection_dim = config.projection_dim
+ self.projection_intermediate_dim = config.projection_intermediate_dim
+ self.text_embed_dim = text_config.hidden_size
+ self.vision_embed_dim = vision_config.hidden_size
+
+ self.text_model = GroupViTTextTransformer(text_config)
+ self.vision_model = GroupViTVisionTransformer(vision_config)
+
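+        # unlike CLIP's single linear projection, GroupViT projects both modalities with a small
+        # Linear -> BatchNorm1d -> ReLU -> Linear head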
+ self.visual_projection = nn.Sequential(
+ nn.Linear(self.vision_embed_dim, self.projection_intermediate_dim, bias=True),
+ nn.BatchNorm1d(self.projection_intermediate_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.projection_intermediate_dim, self.projection_dim, bias=True),
+ )
+ self.text_projection = nn.Sequential(
+ nn.Linear(self.text_embed_dim, self.projection_intermediate_dim, bias=True),
+ nn.BatchNorm1d(self.projection_intermediate_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.projection_intermediate_dim, self.projection_dim, bias=True),
+ )
+ self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING)
+ def get_text_features(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+ applying the projection layer to the pooled output of [`GroupViTTextModel`].
+
+ Examples:
+
+ ```python
+ >>> from transformers import CLIPTokenizer, GroupViTModel
+
+ >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+ >>> text_features = model.get_text_features(**inputs)
+ ```"""
+ # Use GROUPVIT model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = text_outputs[1]
+ text_features = self.text_projection(pooled_output)
+
+ return text_features
+
+ @add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING)
+ def get_image_features(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+ applying the projection layer to the pooled output of [`GroupViTVisionModel`].
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, GroupViTModel
+
+ >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> image_features = model.get_image_features(**inputs)
+ ```"""
+ # Use GROUPVIT model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = vision_outputs[1] # pooled_output
+ image_features = self.visual_projection(pooled_output)
+
+ return image_features
+
+ @add_start_docstrings_to_model_forward(GROUPVIT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=GroupViTModelOutput, config_class=GroupViTConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ return_loss: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_segmentation: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, GroupViTModelOutput]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, GroupViTModel
+
+ >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
+ ... )
+
+ >>> outputs = model(**inputs)
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
+ ```"""
+ # Use GROUPVIT model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_segmentation = (
+ output_segmentation if output_segmentation is not None else self.config.output_segmentation
+ )
+ if output_segmentation:
+ output_attentions = True
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ image_embeds = vision_outputs[1]
+ image_embeds = self.visual_projection(image_embeds)
+
+ text_embeds = text_outputs[1]
+ text_embeds = self.text_projection(text_embeds)
+
+ # normalized features
+ image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+ logits_per_image = logits_per_text.T
+
+ seg_logits = None
+ if output_segmentation:
+ # grouped features
+ # [batch_size_image, num_group, hidden_size]
+ image_group_embeds = vision_outputs[0]
+ # [batch_size_image*num_group, hidden_size]
+ image_group_embeds = self.visual_projection(image_group_embeds.reshape(-1, image_group_embeds.shape[-1]))
+ if output_hidden_states:
+ attentions = vision_outputs[3]
+ else:
+ attentions = vision_outputs[2]
+ # [batch_size_image, num_group, height, width]
+ grouping = get_grouping_from_attentions(attentions, pixel_values.shape[2:])
+
+ # normalized features
+ image_group_embeds = image_group_embeds / image_group_embeds.norm(dim=-1, keepdim=True)
+ # [batch_size_image x num_group, batch_size_text]
+ logits_per_image_group = torch.matmul(image_group_embeds, text_embeds.t()) * logit_scale
+ # [batch_size_image, batch_size_text, num_group]
+ logits_per_image_group = logits_per_image_group.reshape(
+ image_embeds.shape[0], -1, text_embeds.shape[0]
+ ).permute(0, 2, 1)
+
+ # [batch_size_image, batch_size_text, height x width]
+ flatten_grouping = grouping.reshape(grouping.shape[0], grouping.shape[1], -1)
+
+ # [batch_size_image, batch_size_text, height, width]
+ seg_logits = torch.matmul(logits_per_image_group, flatten_grouping) * logit_scale
+ seg_logits = seg_logits.reshape(
+ seg_logits.shape[0], seg_logits.shape[1], grouping.shape[2], grouping.shape[3]
+ )
+
+ loss = None
+ if return_loss:
+ loss = groupvit_loss(logits_per_text)
+
+ if not return_dict:
+ if seg_logits is not None:
+ output = (
+ logits_per_image,
+ logits_per_text,
+ seg_logits,
+ text_embeds,
+ image_embeds,
+ text_outputs,
+ vision_outputs,
+ )
+ else:
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+ return ((loss,) + output) if loss is not None else output
+
+ return GroupViTModelOutput(
+ loss=loss,
+ logits_per_image=logits_per_image,
+ logits_per_text=logits_per_text,
+ segmentation_logits=seg_logits,
+ text_embeds=text_embeds,
+ image_embeds=image_embeds,
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
+ )
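
For reference, the contrastive head in `GroupViTModel.forward` above boils down to the following standalone sketch; the shapes and helper name are illustrative and not part of the patch:

```python
# Minimal sketch (not the library code) of the contrastive-logit computation in
# GroupViTModel.forward: embeddings are L2-normalized, so the matmul is a cosine
# similarity, scaled by a learned temperature.
import torch


def contrastive_logits(image_embeds: torch.Tensor, text_embeds: torch.Tensor, logit_scale: torch.Tensor):
    # normalize so the dot product becomes a cosine similarity
    image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    scale = logit_scale.exp()
    logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * scale
    return logits_per_text.t(), logits_per_text  # (logits_per_image, logits_per_text)


# toy shapes: 2 images, 3 captions, 256-d projections
imgs, txts = torch.randn(2, 256), torch.randn(3, 256)
logits_per_image, logits_per_text = contrastive_logits(imgs, txts, torch.zeros(()))
print(logits_per_image.shape)  # torch.Size([2, 3])
```
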
diff --git a/src/transformers/models/herbert/__init__.py b/src/transformers/models/herbert/__init__.py
index 4cd458b4e843..ef9d47535e5f 100644
--- a/src/transformers/models/herbert/__init__.py
+++ b/src/transformers/models/herbert/__init__.py
@@ -18,21 +18,29 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available
-_import_structure = {
- "tokenization_herbert": ["HerbertTokenizer"],
-}
+_import_structure = {"tokenization_herbert": ["HerbertTokenizer"]}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_herbert_fast"] = ["HerbertTokenizerFast"]
if TYPE_CHECKING:
from .tokenization_herbert import HerbertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_herbert_fast import HerbertTokenizerFast
else:
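
The same `OptionalDependencyNotAvailable` guard is applied to many `__init__.py` files below. A structural sketch of the pattern for a hypothetical `foo` model (module and class names are placeholders, and the relative imports only resolve inside the package):

```python
# Sketch of the import-guard pattern this diff introduces: instead of a bare
# `if is_torch_available():`, a try/except around an explicit raise skips
# registering the torch-backed objects when the optional backend is missing.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

_import_structure = {"configuration_foo": ["FooConfig"]}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass  # torch missing: the torch-backed classes are simply not registered
else:
    _import_structure["modeling_foo"] = ["FooModel"]

if TYPE_CHECKING:
    from .configuration_foo import FooConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_foo import FooModel
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
```
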
diff --git a/src/transformers/models/hubert/__init__.py b/src/transformers/models/hubert/__init__.py
index 59f848c11872..bd415e49a150 100644
--- a/src/transformers/models/hubert/__init__.py
+++ b/src/transformers/models/hubert/__init__.py
@@ -17,15 +17,17 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
-_import_structure = {
- ".wav2vec2.feature_extraction_wav2vec2": ["Wav2Vec2FeatureExtractor"],
- "configuration_hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"],
-}
+_import_structure = {"configuration_hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_hubert"] = [
"HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"HubertForCTC",
@@ -35,7 +37,12 @@
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_hubert"] = [
"TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFHubertForCTC",
@@ -44,10 +51,14 @@
]
if TYPE_CHECKING:
- from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .configuration_hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_hubert import (
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
HubertForCTC,
@@ -56,7 +67,12 @@
HubertPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_hubert import (
TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFHubertForCTC,
diff --git a/src/transformers/models/hubert/configuration_hubert.py b/src/transformers/models/hubert/configuration_hubert.py
index 9b104aa9c528..be2e6bbf4c71 100644
--- a/src/transformers/models/hubert/configuration_hubert.py
+++ b/src/transformers/models/hubert/configuration_hubert.py
@@ -82,10 +82,10 @@ class HubertConfig(PretrainedConfig):
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
- of *conv_stride* defines the number of convolutional layers and has to match the the length of *conv_dim*.
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
- length of *conv_kernel* defines the number of convolutional layers and has to match the the length of
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
*conv_dim*.
conv_bias (`bool`, *optional*, defaults to `False`):
Whether the 1D convolutional layers have a bias.
@@ -233,10 +233,10 @@ def __init__(
or (len(self.conv_dim) != self.num_feat_extract_layers)
):
raise ValueError(
- "Configuration for convolutional layers is incorrect. "
- "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
- f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
- f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
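
The reworded error above enforces that the three convolutional tuples are equally long. A minimal sketch of the check, using the stride and kernel defaults documented above (the `conv_dim` values are assumed):

```python
# Illustration of the constraint behind the reworded HubertConfig error message:
# the lengths of conv_dim, conv_stride and conv_kernel must all match.
conv_dim = (512, 512, 512, 512, 512, 512, 512)  # assumed default
conv_stride = (5, 2, 2, 2, 2, 2, 2)
conv_kernel = (10, 3, 3, 3, 3, 3, 3)

if not (len(conv_dim) == len(conv_stride) == len(conv_kernel)):
    raise ValueError(
        "Configuration for convolutional layers is incorrect. It is required that"
        " `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`."
    )
```
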
diff --git a/src/transformers/models/hubert/convert_distilhubert_original_s3prl_checkpoint_to_pytorch.py b/src/transformers/models/hubert/convert_distilhubert_original_s3prl_checkpoint_to_pytorch.py
index c1963faa73b3..d7ba74fedae7 100644
--- a/src/transformers/models/hubert/convert_distilhubert_original_s3prl_checkpoint_to_pytorch.py
+++ b/src/transformers/models/hubert/convert_distilhubert_original_s3prl_checkpoint_to_pytorch.py
@@ -51,9 +51,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -121,28 +122,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
index dee823e094d6..9a70fb6db710 100644
--- a/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
@@ -64,9 +64,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -134,28 +135,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py
index 5af0197fb95c..d6cb6b8e0599 100755
--- a/src/transformers/models/hubert/modeling_hubert.py
+++ b/src/transformers/models/hubert/modeling_hubert.py
@@ -174,7 +174,7 @@ def compute_num_masked_span(input_length):
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
- # add offset to the starting indexes so that that indexes now create a span
+ # add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
@@ -488,7 +488,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -504,7 +505,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -525,7 +527,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -657,10 +660,12 @@ def forward(
if attention_mask is not None:
# make sure padded tokens output 0
- hidden_states[~attention_mask] = 0.0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
- attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
@@ -745,10 +750,12 @@ def forward(
if attention_mask is not None:
# make sure padded tokens are not attended to
- hidden_states[~attention_mask] = 0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
- attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
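
The masking change above can be exercised in isolation. A sketch with toy shapes (assumed sizes, not library code) of the dtype-aware additive mask that replaces the hard-coded `-10000.0`:

```python
# Padded token states are zeroed via an expanded boolean mask, then the padding
# mask is turned into an additive (batch, 1, seq, seq) mask whose "blocked"
# value is the most negative number representable in the working dtype, which
# stays valid under fp16.
import torch

hidden_states = torch.randn(2, 5, 8, dtype=torch.float16)  # (batch, seq, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)

# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0

# extend attention_mask to an additive mask for the attention scores
additive = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
additive = additive * torch.finfo(hidden_states.dtype).min
additive = additive.expand(additive.shape[0], 1, additive.shape[-1], additive.shape[-1])
print(additive.shape)  # torch.Size([2, 1, 5, 5])
```
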
diff --git a/src/transformers/models/hubert/modeling_tf_hubert.py b/src/transformers/models/hubert/modeling_tf_hubert.py
index 540090871feb..f078b5d0cfc7 100644
--- a/src/transformers/models/hubert/modeling_tf_hubert.py
+++ b/src/transformers/models/hubert/modeling_tf_hubert.py
@@ -95,12 +95,14 @@ def input_values_processing(func, config, input_values, **kwargs):
output[parameter_names[i]] = input
else:
raise ValueError(
- f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
+ f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
+ f" {parameter_names[i]}."
)
elif isinstance(input_values, Mapping):
if "inputs" in input_values:
warnings.warn(
- "The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
+ "The `inputs` argument is deprecated and will be removed in a future version, use `input_values`"
+ " instead.",
FutureWarning,
)
@@ -108,7 +110,8 @@ def input_values_processing(func, config, input_values, **kwargs):
if "decoder_cached_states" in input_values:
warnings.warn(
- "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
+ "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
+ " `past_key_values` instead.",
FutureWarning,
)
output["past_key_values"] = input_values.pop("decoder_cached_states")
@@ -128,7 +131,8 @@ def input_values_processing(func, config, input_values, **kwargs):
output[parameter_names[0]] = input_values
else:
raise ValueError(
- f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
+ f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for"
+ f" {parameter_names[0]}."
)
for name in parameter_names:
@@ -199,7 +203,7 @@ def _compute_mask_indices(
Computes random mask spans for a given shape
Args:
- shape: the the shape for which to compute masks.
+ shape: the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob:
@@ -219,15 +223,17 @@ def _compute_mask_indices(
if mask_length > sequence_length:
raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
+ f" `sequence_length`: {sequence_length}`"
)
# compute number of masked spans in batch
- num_masked_spans = int(mask_prob * sequence_length / mask_length + tf.random.uniform((1,)))
- num_masked_spans = max(num_masked_spans, min_masks)
+ num_masked_spans = mask_prob * sequence_length / mask_length + tf.random.uniform((1,))
+ num_masked_spans = tf.maximum(num_masked_spans, min_masks)
+ num_masked_spans = tf.cast(num_masked_spans, tf.int32)
# make sure num masked indices <= sequence_length
- if num_masked_spans * mask_length > sequence_length:
- num_masked_spans = sequence_length // mask_length
+ num_masked_spans = tf.math.minimum(sequence_length // mask_length, num_masked_spans)
+ num_masked_spans = tf.squeeze(num_masked_spans)
# SpecAugment mask to fill
spec_aug_mask = tf.zeros((batch_size, sequence_length), dtype=tf.int32)
@@ -251,14 +257,14 @@ def _compute_mask_indices(
# scatter indices to mask
spec_aug_mask = _scatter_values_on_batch_indices(
- tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, spec_aug_mask.shape
+ tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, tf.shape(spec_aug_mask)
)
return spec_aug_mask
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
-def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -408,9 +414,11 @@ def _check_if_input_shape_is_none(self, input_shape):
dim = input_shape[self.axis]
if dim is None:
raise ValueError(
- "Axis " + str(self.axis) + " of "
- "input tensor should have a defined dimension "
- "but the layer received an input with shape " + str(input_shape) + "."
+ "Axis "
+ + str(self.axis)
+ + " of input tensor should have a defined dimension but the layer received an input with shape "
+ + str(input_shape)
+ + "."
)
def _set_number_of_groups_for_instance_norm(self, input_shape):
@@ -424,22 +432,27 @@ def _check_size_of_dimensions(self, input_shape):
dim = input_shape[self.axis]
if dim < self.groups:
raise ValueError(
- "Number of groups (" + str(self.groups) + ") cannot be "
- "more than the number of channels (" + str(dim) + ")."
+ "Number of groups ("
+ + str(self.groups)
+ + ") cannot be more than the number of channels ("
+ + str(dim)
+ + ")."
)
if dim % self.groups != 0:
raise ValueError(
- "Number of groups (" + str(self.groups) + ") must be a "
- "multiple of the number of channels (" + str(dim) + ")."
+ "Number of groups ("
+ + str(self.groups)
+ + ") must be a multiple of the number of channels ("
+ + str(dim)
+ + ")."
)
def _check_axis(self):
if self.axis == 0:
raise ValueError(
- "You are trying to normalize your batch axis. Do you want to "
- "use tf.layer.batch_normalization instead"
+ "You are trying to normalize your batch axis. Do you want to use tf.layer.batch_normalization instead"
)
def _create_input_spec(self, input_shape):
@@ -809,7 +822,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
- message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
)
if attention_mask is not None:
@@ -819,7 +835,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
- message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
@@ -835,7 +854,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
@@ -852,7 +874,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
- message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
)
attn_output = tf.transpose(
@@ -1295,7 +1320,15 @@ def __init__(self, config, *inputs, **kwargs):
"to train/fine-tine this model, you need a GPU or a TPU"
)
- @tf.function
+ @tf.function(
+ input_signature=[
+ {
+ "input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"),
+ "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
+ "token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
+ }
+ ]
+ )
def serving(self, inputs):
output = self.call(input_values=inputs, training=False)
@@ -1487,10 +1520,11 @@ def call(
return outputs
def serving_output(self, output):
- hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
- attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
-
- return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns)
+ hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+ return TFBaseModelOutput(
+ last_hidden_state=output.last_hidden_state, hidden_states=hidden_states, attentions=attentions
+ )
@add_start_docstrings(
@@ -1578,9 +1612,8 @@ def call(
>>> # compute loss
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
- >>> # wrap processor as target processor to encode labels
- >>> with processor.as_target_processor():
- ... labels = processor(transcription, return_tensors="tf").input_values
+ >>> # Pass the transcription as text to encode labels
+ >>> labels = processor(text=transcription, return_tensors="tf").input_values
>>> loss = model(input_values, labels=labels).loss
```"""
@@ -1661,6 +1694,6 @@ def call(
)
def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
- hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
- attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
- return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
+ hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+ return TFCausalLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
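
The `_compute_mask_indices` change earlier in this file replaces Python-level control flow with TF ops so the function stays traceable inside a graph. A minimal sketch of that span-count computation (the function name and argument values are illustrative):

```python
# Graph-friendly version of "how many spans to mask": the clamping that used to
# be done with Python int() and an if-branch is expressed with tf.maximum /
# tf.math.minimum so it can run under tf.function tracing.
import tensorflow as tf


def num_masked_spans(sequence_length: int, mask_prob: float, mask_length: int, min_masks: int = 0) -> tf.Tensor:
    num = mask_prob * sequence_length / mask_length + tf.random.uniform((1,))
    num = tf.maximum(num, min_masks)
    num = tf.cast(num, tf.int32)
    # make sure num masked indices <= sequence_length
    num = tf.math.minimum(sequence_length // mask_length, num)
    return tf.squeeze(num)


print(num_masked_spans(sequence_length=100, mask_prob=0.05, mask_length=10))
```
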
diff --git a/src/transformers/models/ibert/__init__.py b/src/transformers/models/ibert/__init__.py
index e941b88f256e..0480da8c47fe 100644
--- a/src/transformers/models/ibert/__init__.py
+++ b/src/transformers/models/ibert/__init__.py
@@ -18,14 +18,17 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
-_import_structure = {
- "configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig", "IBertOnnxConfig"],
-}
+_import_structure = {"configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig", "IBertOnnxConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_ibert"] = [
"IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"IBertForMaskedLM",
@@ -40,7 +43,12 @@
if TYPE_CHECKING:
from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig, IBertOnnxConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_ibert import (
IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
IBertForMaskedLM,
diff --git a/src/transformers/models/ibert/configuration_ibert.py b/src/transformers/models/ibert/configuration_ibert.py
index 17f6d37e7d46..32d4d2e56a80 100644
--- a/src/transformers/models/ibert/configuration_ibert.py
+++ b/src/transformers/models/ibert/configuration_ibert.py
@@ -29,7 +29,9 @@
IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"kssteven/ibert-roberta-base": "https://huggingface.co/kssteven/ibert-roberta-base/resolve/main/config.json",
"kssteven/ibert-roberta-large": "https://huggingface.co/kssteven/ibert-roberta-large/resolve/main/config.json",
- "kssteven/ibert-roberta-large-mnli": "https://huggingface.co/kssteven/ibert-roberta-large-mnli/resolve/main/config.json",
+ "kssteven/ibert-roberta-large-mnli": (
+ "https://huggingface.co/kssteven/ibert-roberta-large-mnli/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/ibert/quant_modules.py b/src/transformers/models/ibert/quant_modules.py
index e6eab6ce6201..fa657924645e 100644
--- a/src/transformers/models/ibert/quant_modules.py
+++ b/src/transformers/models/ibert/quant_modules.py
@@ -150,7 +150,7 @@ def __init__(self, activation_bit, act_range_momentum=0.95, per_channel=False, c
def __repr__(self):
return (
f"{self.__class__.__name__}(activation_bit={self.activation_bit}, "
- f"quant_mode: {self.activation_bit}, Act_min: {self.x_min.item():.2f}, "
+ f"quant_mode: {self.quant_mode}, Act_min: {self.x_min.item():.2f}, "
f"Act_max: {self.x_max.item():.2f})"
)
diff --git a/src/transformers/models/imagegpt/__init__.py b/src/transformers/models/imagegpt/__init__.py
index f82d0cb989ec..ecf7ba9408d1 100644
--- a/src/transformers/models/imagegpt/__init__.py
+++ b/src/transformers/models/imagegpt/__init__.py
@@ -18,17 +18,25 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
-_import_structure = {
- "configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig"],
-}
+_import_structure = {"configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_imagegpt"] = ["ImageGPTFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_imagegpt"] = [
"IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ImageGPTForCausalImageModeling",
@@ -42,10 +50,20 @@
if TYPE_CHECKING:
from .configuration_imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_imagegpt import ImageGPTFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_imagegpt import (
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST,
ImageGPTForCausalImageModeling,
diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py
index 22186a6159e3..e71ea4a272c2 100755
--- a/src/transformers/models/imagegpt/modeling_imagegpt.py
+++ b/src/transformers/models/imagegpt/modeling_imagegpt.py
@@ -21,12 +21,18 @@
import torch
import torch.utils.checkpoint
-from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from ...pytorch_utils import (
+ Conv1D,
+ find_pruneable_heads_and_indices,
+ is_torch_greater_or_equal_than_1_6,
+ prune_conv1d_layer,
+)
+
-if version.parse(torch.__version__) >= version.parse("1.6"):
+if is_torch_greater_or_equal_than_1_6:
is_amp_available = True
from torch.cuda.amp import autocast
else:
@@ -39,7 +45,6 @@
SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_imagegpt import ImageGPTConfig
@@ -200,7 +205,8 @@ def __init__(self, config, is_cross_attention: Optional[bool] = False, layer_idx
self.split_size = self.embed_dim
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
- f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
)
self.scale_attn_weights = config.scale_attn_weights
@@ -252,7 +258,11 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
- attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
# Apply the attention mask
@@ -303,7 +313,11 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
- attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
# Apply the attention mask
@@ -699,14 +713,14 @@ def forward(
if "pixel_values" in kwargs:
warnings.warn(
- "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
+ "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids`"
+ " instead.",
FutureWarning,
)
if input_ids is not None:
raise ValueError(
- "You cannot pass both `pixel_values` and `input_ids`. "
- "Please make sure to only pass `input_ids`."
+ "You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`."
)
input_ids = kwargs.pop("pixel_values")
@@ -764,7 +778,7 @@ def forward(
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
- attention_mask = (1.0 - attention_mask) * -10000.0
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1010,14 +1024,14 @@ def forward(
if "pixel_values" in kwargs:
warnings.warn(
- "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
+ "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids`"
+ " instead.",
FutureWarning,
)
if input_ids is not None:
raise ValueError(
- "You cannot pass both `pixel_values` and `input_ids`. "
- "Please make sure to only pass `input_ids`."
+ "You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`."
)
input_ids = kwargs.pop("pixel_values")
@@ -1086,7 +1100,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
IMAGEGPT_START_DOCSTRING,
)
class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
def __init__(self, config: ImageGPTConfig):
super().__init__(config)
@@ -1143,14 +1157,14 @@ def forward(
if "pixel_values" in kwargs:
warnings.warn(
- "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
+ "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids`"
+ " instead.",
FutureWarning,
)
if input_ids is not None:
raise ValueError(
- "You cannot pass both `pixel_values` and `input_ids`. "
- "Please make sure to only pass `input_ids`."
+ "You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`."
)
input_ids = kwargs.pop("pixel_values")
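
The `torch.finfo`-based fill value used in the ImageGPT attention above can be seen in isolation; a sketch with toy sizes (not library code):

```python
# The fill value is materialized as a tensor with the scores' dtype and device,
# so torch.where does not hit a dtype/device mismatch, and it uses the minimum
# representable value of that dtype instead of a hard-coded constant.
import torch

attn_weights = torch.randn(1, 1, 4, 4)                 # (batch, heads, query, key)
causal_mask = torch.ones(4, 4).tril().bool()           # lower-triangular causal mask

mask_value = torch.finfo(attn_weights.dtype).min
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
probs = attn_weights.softmax(dim=-1)  # future positions receive ~zero probability
```
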
diff --git a/src/transformers/models/layoutlm/__init__.py b/src/transformers/models/layoutlm/__init__.py
index b77edddc4d7e..a7ccae38e89e 100644
--- a/src/transformers/models/layoutlm/__init__.py
+++ b/src/transformers/models/layoutlm/__init__.py
@@ -18,9 +18,13 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
-from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
-from .tokenization_layoutlm import LayoutLMTokenizer
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -28,10 +32,20 @@
"tokenization_layoutlm": ["LayoutLMTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_layoutlm_fast"] = ["LayoutLMTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_layoutlm"] = [
"LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"LayoutLMForMaskedLM",
@@ -41,7 +55,12 @@
"LayoutLMPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_layoutlm"] = [
"TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFLayoutLMForMaskedLM",
@@ -57,10 +76,20 @@
from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMOnnxConfig
from .tokenization_layoutlm import LayoutLMTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_layoutlm_fast import LayoutLMTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_layoutlm import (
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMForMaskedLM,
@@ -69,7 +98,12 @@
LayoutLMModel,
LayoutLMPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_layoutlm import (
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLayoutLMForMaskedLM,
diff --git a/src/transformers/models/layoutlm/configuration_layoutlm.py b/src/transformers/models/layoutlm/configuration_layoutlm.py
index 9b77b2ce3f93..94100791d39f 100644
--- a/src/transformers/models/layoutlm/configuration_layoutlm.py
+++ b/src/transformers/models/layoutlm/configuration_layoutlm.py
@@ -27,8 +27,12 @@
logger = logging.get_logger(__name__)
LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "microsoft/layoutlm-base-uncased": "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/config.json",
- "microsoft/layoutlm-large-uncased": "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/config.json",
+ "microsoft/layoutlm-base-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/config.json"
+ ),
+ "microsoft/layoutlm-large-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py
index 174813ffb216..e3a625416a7d 100644
--- a/src/transformers/models/layoutlm/modeling_layoutlm.py
+++ b/src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -398,7 +398,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -799,12 +800,12 @@ def forward(
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
if bbox is None:
- bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
+ bbox = torch.zeros(input_shape + (4,), dtype=torch.long, device=device)
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
if head_mask is not None:
if head_mask.dim() == 1:
diff --git a/src/transformers/models/layoutlm/modeling_tf_layoutlm.py b/src/transformers/models/layoutlm/modeling_tf_layoutlm.py
index b184cb352e20..d15fc29b7366 100644
--- a/src/transformers/models/layoutlm/modeling_tf_layoutlm.py
+++ b/src/transformers/models/layoutlm/modeling_tf_layoutlm.py
@@ -453,8 +453,8 @@ def call(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
- "by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
diff --git a/src/transformers/models/layoutlm/tokenization_layoutlm.py b/src/transformers/models/layoutlm/tokenization_layoutlm.py
index 6ef9a9c3a005..1cd0a5f6e087 100644
--- a/src/transformers/models/layoutlm/tokenization_layoutlm.py
+++ b/src/transformers/models/layoutlm/tokenization_layoutlm.py
@@ -25,8 +25,12 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "microsoft/layoutlm-base-uncased": "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/vocab.txt",
- "microsoft/layoutlm-large-uncased": "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/vocab.txt",
+ "microsoft/layoutlm-base-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/vocab.txt"
+ ),
+ "microsoft/layoutlm-large-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/vocab.txt"
+ ),
}
}
diff --git a/src/transformers/models/layoutlm/tokenization_layoutlm_fast.py b/src/transformers/models/layoutlm/tokenization_layoutlm_fast.py
index 90ba0a94feab..a614c3e61559 100644
--- a/src/transformers/models/layoutlm/tokenization_layoutlm_fast.py
+++ b/src/transformers/models/layoutlm/tokenization_layoutlm_fast.py
@@ -26,12 +26,20 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "microsoft/layoutlm-base-uncased": "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/vocab.txt",
- "microsoft/layoutlm-large-uncased": "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/vocab.txt",
+ "microsoft/layoutlm-base-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/vocab.txt"
+ ),
+ "microsoft/layoutlm-large-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "microsoft/layoutlm-base-uncased": "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/tokenizer.json",
- "microsoft/layoutlm-large-uncased": "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/tokenizer.json",
+ "microsoft/layoutlm-base-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/tokenizer.json"
+ ),
+ "microsoft/layoutlm-large-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/layoutlmv2/__init__.py b/src/transformers/models/layoutlmv2/__init__.py
index 9f7a8dae39ac..beaacb815843 100644
--- a/src/transformers/models/layoutlmv2/__init__.py
+++ b/src/transformers/models/layoutlmv2/__init__.py
@@ -18,22 +18,43 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available, is_vision_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tokenizers_available,
+ is_torch_available,
+ is_vision_available,
+)
_import_structure = {
"configuration_layoutlmv2": ["LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMv2Config"],
+ "processing_layoutlmv2": ["LayoutLMv2Processor"],
"tokenization_layoutlmv2": ["LayoutLMv2Tokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_layoutlmv2_fast"] = ["LayoutLMv2TokenizerFast"]
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_layoutlmv2"] = ["LayoutLMv2FeatureExtractor"]
- _import_structure["processing_layoutlmv2"] = ["LayoutLMv2Processor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_layoutlmv2"] = [
"LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST",
"LayoutLMv2ForQuestionAnswering",
@@ -46,16 +67,31 @@
if TYPE_CHECKING:
from .configuration_layoutlmv2 import LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMv2Config
+ from .processing_layoutlmv2 import LayoutLMv2Processor
from .tokenization_layoutlmv2 import LayoutLMv2Tokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_layoutlmv2_fast import LayoutLMv2TokenizerFast
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_layoutlmv2 import LayoutLMv2FeatureExtractor
- from .processing_layoutlmv2 import LayoutLMv2Processor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_layoutlmv2 import (
LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMv2ForQuestionAnswering,
diff --git a/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py b/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py
index 12fe27f1a17e..cd05819e479a 100644
--- a/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py
+++ b/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py
@@ -46,11 +46,11 @@ def normalize_box(box, width, height):
]
-def apply_tesseract(image: Image.Image, lang: Optional[str]):
+def apply_tesseract(image: Image.Image, lang: Optional[str], tesseract_config: Optional[str]):
"""Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
# apply OCR
- data = pytesseract.image_to_data(image, lang=lang, output_type="dict")
+ data = pytesseract.image_to_data(image, lang=lang, output_type="dict", config=tesseract_config)
words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
# filter empty words and corresponding coordinates
@@ -100,9 +100,12 @@ class LayoutLMv2FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
if `do_resize` is set to `True`.
apply_ocr (`bool`, *optional*, defaults to `True`):
Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
- ocr_lang (`Optional[str]`, *optional*):
+ ocr_lang (`str`, *optional*):
The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
used.
+ tesseract_config (`str`, *optional*):
+ Any additional custom configuration flags that are forwarded to the `config` parameter when calling
+ Tesseract. For example: '--psm 6'.
@@ -112,13 +115,23 @@ class LayoutLMv2FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
model_input_names = ["pixel_values"]
- def __init__(self, do_resize=True, size=224, resample=Image.BILINEAR, apply_ocr=True, ocr_lang=None, **kwargs):
+ def __init__(
+ self,
+ do_resize=True,
+ size=224,
+ resample=Image.BILINEAR,
+ apply_ocr=True,
+ ocr_lang=None,
+ tesseract_config="",
+ **kwargs
+ ):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.apply_ocr = apply_ocr
self.ocr_lang = ocr_lang
+ self.tesseract_config = tesseract_config
def __call__(
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
@@ -201,7 +214,7 @@ def __call__(
words_batch = []
boxes_batch = []
for image in images:
- words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang)
+ words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang, self.tesseract_config)
words_batch.append(words)
boxes_batch.append(boxes)
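To illustrate the new `tesseract_config` argument introduced above (this sketch is not part of the patch), the snippet below forwards a page-segmentation flag to Tesseract. It assumes pytesseract and the Tesseract binary are installed, and `document.png` is a placeholder path.

```python
from PIL import Image

from transformers import LayoutLMv2FeatureExtractor

# "--psm 6" is only an example flag; any Tesseract CLI config string works here
feature_extractor = LayoutLMv2FeatureExtractor(ocr_lang="eng", tesseract_config="--psm 6")

image = Image.open("document.png").convert("RGB")  # placeholder file name

# apply_tesseract now receives the config string and hands it to
# pytesseract.image_to_data(..., config="--psm 6")
encoding = feature_extractor(image, return_tensors="pt")
print(encoding.keys())  # dict_keys(['pixel_values', 'words', 'boxes'])
```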
diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
index 269e951ea00d..be31af99d6df 100755
--- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
+++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
@@ -14,7 +14,6 @@
# limitations under the License.
""" PyTorch LayoutLMv2 model."""
-
import math
from typing import Optional, Tuple, Union
@@ -179,7 +178,9 @@ def forward(
attention_scores += rel_pos
if self.has_spatial_attention_bias:
attention_scores += rel_2d_pos
- attention_scores = attention_scores.float().masked_fill_(attention_mask.to(torch.bool), float("-inf"))
+ attention_scores = attention_scores.float().masked_fill_(
+ attention_mask.to(torch.bool), torch.finfo(attention_scores.dtype).min
+ )
attention_probs = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).type_as(value_layer)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@@ -804,6 +805,16 @@ def _calc_visual_bbox(self, image_feature_pool_shape, bbox, device, final_shape)
return visual_bbox
+ def _get_input_shape(self, input_ids=None, inputs_embeds=None):
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ return input_ids.size()
+ elif inputs_embeds is not None:
+ return inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
@add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -821,45 +832,49 @@ def forward(
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
- Returns:
+ Return:
Examples:
```python
- >>> from transformers import LayoutLMv2Processor, LayoutLMv2Model
+ >>> from transformers import LayoutLMv2Processor, LayoutLMv2Model, set_seed
>>> from PIL import Image
+ >>> import torch
+ >>> from datasets import load_dataset
+
+ >>> set_seed(88)
>>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
>>> model = LayoutLMv2Model.from_pretrained("microsoft/layoutlmv2-base-uncased")
- >>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
+
+ >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa")
+ >>> image_path = dataset["test"][0]["file"]
+ >>> image = Image.open(image_path).convert("RGB")
>>> encoding = processor(image, return_tensors="pt")
>>> outputs = model(**encoding)
>>> last_hidden_states = outputs.last_hidden_state
- ```"""
+
+ >>> last_hidden_states.shape
+ torch.Size([1, 342, 768])
+ ```
+ """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
- elif input_ids is not None:
- input_shape = input_ids.size()
- elif inputs_embeds is not None:
- input_shape = inputs_embeds.size()[:-1]
- else:
- raise ValueError("You have to specify either input_ids or inputs_embeds")
-
+ input_shape = self._get_input_shape(input_ids, inputs_embeds)
device = input_ids.device if input_ids is not None else inputs_embeds.device
visual_shape = list(input_shape)
visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
visual_shape = torch.Size(visual_shape)
- final_shape = list(input_shape)
+ # tracing needs a fresh copy of input_shape; reusing the list built above would bake in wrong dimensions
+ final_shape = list(self._get_input_shape(input_ids, inputs_embeds))
final_shape[1] += visual_shape[1]
final_shape = torch.Size(final_shape)
@@ -906,7 +921,7 @@ def forward(
extended_attention_mask = final_attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
if head_mask is not None:
if head_mask.dim() == 1:
@@ -990,25 +1005,37 @@ def forward(
Returns:
- Examples:
+ Example:
```python
- >>> from transformers import LayoutLMv2Processor, LayoutLMv2ForSequenceClassification
+ >>> from transformers import LayoutLMv2Processor, LayoutLMv2ForSequenceClassification, set_seed
>>> from PIL import Image
>>> import torch
+ >>> from datasets import load_dataset
- >>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
- >>> model = LayoutLMv2ForSequenceClassification.from_pretrained("microsoft/layoutlmv2-base-uncased")
+ >>> set_seed(88)
+
+ >>> dataset = load_dataset("rvl_cdip", split="train", streaming=True)
+ >>> data = next(iter(dataset))
+ >>> image = data["image"].convert("RGB")
- >>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
+ >>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
+ >>> model = LayoutLMv2ForSequenceClassification.from_pretrained(
+ ... "microsoft/layoutlmv2-base-uncased", num_labels=dataset.info.features["label"].num_classes
+ ... )
>>> encoding = processor(image, return_tensors="pt")
- >>> sequence_label = torch.tensor([1])
+ >>> sequence_label = torch.tensor([data["label"]])
>>> outputs = model(**encoding, labels=sequence_label)
- >>> loss = outputs.loss
- >>> logits = outputs.logits
- ```"""
+
+ >>> loss, logits = outputs.loss, outputs.logits
+ >>> predicted_idx = logits.argmax(dim=-1).item()
+ >>> predicted_answer = dataset.info.features["label"].names[predicted_idx]
+ >>> predicted_idx, predicted_answer
+ (4, 'advertisement')
+ ```
+ """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1157,26 +1184,48 @@ def forward(
Returns:
- Examples:
+ Example:
```python
- >>> from transformers import LayoutLMv2Processor, LayoutLMv2ForTokenClassification
+ >>> from transformers import LayoutLMv2Processor, LayoutLMv2ForTokenClassification, set_seed
>>> from PIL import Image
+ >>> from datasets import load_dataset
- >>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
- >>> model = LayoutLMv2ForTokenClassification.from_pretrained("microsoft/layoutlmv2-base-uncased")
+ >>> set_seed(88)
- >>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
- >>> words = ["hello", "world"]
- >>> boxes = [[1, 2, 3, 4], [5, 6, 7, 8]] # make sure to normalize your bounding boxes
- >>> word_labels = [0, 1]
+ >>> datasets = load_dataset("nielsr/funsd", split="test")
+ >>> labels = datasets.features["ner_tags"].feature.names
+ >>> id2label = {idx: label for idx, label in enumerate(labels)}
- >>> encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
+ >>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
+ >>> model = LayoutLMv2ForTokenClassification.from_pretrained(
+ ... "microsoft/layoutlmv2-base-uncased", num_labels=len(labels)
+ ... )
+
+ >>> data = datasets[0]
+ >>> image = Image.open(data["image_path"]).convert("RGB")
+ >>> words = data["words"]
+ >>> boxes = data["bboxes"] # make sure to normalize your bounding boxes
+ >>> word_labels = data["ner_tags"]
+ >>> encoding = processor(
+ ... image,
+ ... words,
+ ... boxes=boxes,
+ ... word_labels=word_labels,
+ ... padding="max_length",
+ ... truncation=True,
+ ... return_tensors="pt",
+ ... )
>>> outputs = model(**encoding)
- >>> loss = outputs.loss
- >>> logits = outputs.logits
- ```"""
+ >>> logits, loss = outputs.logits, outputs.loss
+
+ >>> predicted_token_class_ids = logits.argmax(-1)
+ >>> predicted_tokens_classes = [id2label[t.item()] for t in predicted_token_class_ids[0]]
+ >>> predicted_tokens_classes[:5]
+ ['B-ANSWER', 'B-HEADER', 'B-HEADER', 'B-HEADER', 'B-HEADER']
+ ```
+ """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1273,28 +1322,49 @@ def forward(
Returns:
- Examples:
+ Example:
+
+ In the example below, we give the LayoutLMv2 model an image (of texts) and ask it a question. It will give us
+ a prediction of what it thinks the answer is (the span of the answer within the texts parsed from the image).
```python
- >>> from transformers import LayoutLMv2Processor, LayoutLMv2ForQuestionAnswering
- >>> from PIL import Image
+ >>> from transformers import LayoutLMv2Processor, LayoutLMv2ForQuestionAnswering, set_seed
>>> import torch
+ >>> from PIL import Image
+ >>> from datasets import load_dataset
+ >>> set_seed(88)
>>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
>>> model = LayoutLMv2ForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-base-uncased")
- >>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
- >>> question = "what's his name?"
-
+ >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa")
+ >>> image_path = dataset["test"][0]["file"]
+ >>> image = Image.open(image_path).convert("RGB")
+ >>> question = "When is coffee break?"
>>> encoding = processor(image, question, return_tensors="pt")
- >>> start_positions = torch.tensor([1])
- >>> end_positions = torch.tensor([3])
-
- >>> outputs = model(**encoding, start_positions=start_positions, end_positions=end_positions)
- >>> loss = outputs.loss
- >>> start_scores = outputs.start_logits
- >>> end_scores = outputs.end_logits
- ```"""
+
+ >>> outputs = model(**encoding)
+ >>> predicted_start_idx = outputs.start_logits.argmax(-1).item()
+ >>> predicted_end_idx = outputs.end_logits.argmax(-1).item()
+ >>> predicted_start_idx, predicted_end_idx
+ (154, 287)
+
+ >>> predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
+ >>> predicted_answer = processor.tokenizer.decode(predicted_answer_tokens)
+ >>> predicted_answer # results are not very good without further fine-tuning
+ 'council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public ...
+ ```
+
+ ```python
+ >>> target_start_index = torch.tensor([7])
+ >>> target_end_index = torch.tensor([14])
+ >>> outputs = model(**encoding, start_positions=target_start_index, end_positions=target_end_index)
+ >>> predicted_answer_span_start = outputs.start_logits.argmax(-1).item()
+ >>> predicted_answer_span_end = outputs.end_logits.argmax(-1).item()
+ >>> predicted_answer_span_start, predicted_answer_span_end
+ (154, 287)
+ ```
+ """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
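A standalone sketch (not taken from the patch) of why the fill value changes from `float("-inf")` / `-10000.0` to `torch.finfo(dtype).min`: the latter is the most negative finite value of the tensor's dtype, so it remains representable under fp16 and cannot turn the softmax into NaNs. Shapes and values below are illustrative only.

```python
import torch

scores = torch.zeros(1, 1, 4, dtype=torch.float16)  # illustrative attention scores
mask = torch.tensor([[[False, False, True, True]]])  # True = position to be masked

# the most negative finite fp16 value, instead of -inf or a hard-coded -10000.0
fill_value = torch.finfo(scores.dtype).min
print(fill_value)  # -65504.0

masked = scores.masked_fill(mask, fill_value)
probs = torch.nn.functional.softmax(masked.float(), dim=-1)
print(probs)  # masked positions get ~0 probability, and no NaNs appear
```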
diff --git a/src/transformers/models/layoutlmv2/processing_layoutlmv2.py b/src/transformers/models/layoutlmv2/processing_layoutlmv2.py
index 449eb4770aaf..57f0b78aed1b 100644
--- a/src/transformers/models/layoutlmv2/processing_layoutlmv2.py
+++ b/src/transformers/models/layoutlmv2/processing_layoutlmv2.py
@@ -86,10 +86,12 @@ def __call__(
if self.feature_extractor.apply_ocr and (word_labels is not None):
raise ValueError(
- "You cannot provide word labels "
- "if you initialized the feature extractor with apply_ocr set to True."
+ "You cannot provide word labels if you initialized the feature extractor with apply_ocr set to True."
)
+ if return_overflowing_tokens is True and return_offsets_mapping is False:
+ raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.")
+
# first, apply the feature extractor
features = self.feature_extractor(images=images, return_tensors=return_tensors)
@@ -122,6 +124,37 @@ def __call__(
)
# add pixel values
- encoded_inputs["image"] = features.pop("pixel_values")
+ images = features.pop("pixel_values")
+ if return_overflowing_tokens is True:
+ images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"])
+ encoded_inputs["image"] = images
return encoded_inputs
+
+ def get_overflowing_images(self, images, overflow_to_sample_mapping):
+ # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image
+ images_with_overflow = []
+ for sample_idx in overflow_to_sample_mapping:
+ images_with_overflow.append(images[sample_idx])
+
+ if len(images_with_overflow) != len(overflow_to_sample_mapping):
+ raise ValueError(
+ "Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got"
+ f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}"
+ )
+
+ return images_with_overflow
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
+ to the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
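To show how the new overflow handling is meant to be used (an illustrative sketch, not part of the patch): `return_overflowing_tokens=True` now requires `return_offsets_mapping=True`, and each overflowed chunk is paired with its own copy of the image via `overflow_to_sample_mapping`. The sketch assumes a fast tokenizer is available and `long_document.png` is a placeholder path.

```python
from PIL import Image

from transformers import LayoutLMv2Processor

processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
image = Image.open("long_document.png").convert("RGB")  # placeholder file name

encoding = processor(
    image,
    max_length=512,
    padding="max_length",
    truncation=True,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,  # omitting this now raises a ValueError
    return_tensors="pt",
)

# get_overflowing_images repeats the pixel values so that every row of
# input_ids is paired with the image it came from
print(len(encoding["image"]) == encoding["input_ids"].shape[0])  # True
```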
diff --git a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py
index b750ede1850b..db934e5e8725 100644
--- a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py
+++ b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py
@@ -38,8 +38,12 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "microsoft/layoutlmv2-base-uncased": "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/vocab.txt",
- "microsoft/layoutlmv2-large-uncased": "https://huggingface.co/microsoft/layoutlmv2-large-uncased/resolve/main/vocab.txt",
+ "microsoft/layoutlmv2-base-uncased": (
+ "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/vocab.txt"
+ ),
+ "microsoft/layoutlmv2-large-uncased": (
+ "https://huggingface.co/microsoft/layoutlmv2-large-uncased/resolve/main/vocab.txt"
+ ),
}
}
@@ -105,53 +109,61 @@
- `'np'`: Return Numpy `np.ndarray` objects.
"""
-
LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
- add_special_tokens (`bool`, *optional*, defaults to `True`):
- Whether or not to encode the sequences with the special tokens relative to their model.
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
- Activates and controls padding. Accepts the following values:
-
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
- sequence if provided).
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
- acceptable input length for the model if that argument is not provided.
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
- lengths).
- truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
- Activates and controls truncation. Accepts the following values:
-
- - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
- to the maximum acceptable input length for the model if that argument is not provided. This will
- truncate token by token, removing a token from the longest sequence in the pair if a pair of
- sequences (or a batch of pairs) is provided.
- - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
- maximum acceptable input length for the model if that argument is not provided. This will only
- truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
- - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
- maximum acceptable input length for the model if that argument is not provided. This will only
- truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
- - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
- greater than the model maximum admissible input size).
- max_length (`int`, *optional*):
- Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to
- `None`, this will use the predefined model maximum length if a maximum length is required by one of the
- truncation/padding parameters. If the model has no specific maximum input length (like XLNet)
- truncation/padding to a maximum length will be deactivated.
- stride (`int`, *optional*, defaults to 0):
- If set to a number along with `max_length`, the overflowing tokens returned when
- `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
- returned to provide some overlap between truncated and overflowing sequences. The value of this
- argument defines the number of overlapping tokens.
- pad_to_multiple_of (`int`, *optional*):
- If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
- the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
+ return_token_type_ids (`bool`, *optional*):
+ Whether to return token type IDs. If left to the default, will return the token type IDs according to
+ the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are attention masks?](../glossary#attention-mask)
+ return_overflowing_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
+ of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead
+ of returning overflowing tokens.
+ return_special_tokens_mask (`bool`, *optional*, defaults to `False`):
+ Whether or not to return special tokens mask information.
+ return_offsets_mapping (`bool`, *optional*, defaults to `False`):
+ Whether or not to return `(char_start, char_end)` for each token.
+
+ This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using
+ Python's tokenizer, this method will raise `NotImplementedError`.
+ return_length (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the lengths of the encoded inputs.
+ verbose (`bool`, *optional*, defaults to `True`):
+ Whether or not to print more information and warnings.
+ **kwargs: passed to the `self.tokenize()` method
+
+ Return:
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ - **bbox** -- List of bounding boxes to be fed to a model.
+
+ - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or
+ if *"token_type_ids"* is in `self.model_input_names`).
+
+ [What are token type IDs?](../glossary#token-type-ids)
+
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ - **labels** -- List of labels to be fed to a model. (when `word_labels` is specified).
+ - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and
+ `return_overflowing_tokens=True`).
+ - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and
+ `return_overflowing_tokens=True`).
+ - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying
+ regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`).
+ - **length** -- The length of the inputs (when `return_length=True`).
"""
@@ -255,8 +267,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
@@ -495,20 +507,23 @@ def _is_valid_text_input(t):
is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
words = text if text_pair is None else text_pair
- assert boxes is not None, "You must provide corresponding bounding boxes"
+ if boxes is None:
+ raise ValueError("You must provide corresponding bounding boxes")
if is_batched:
- assert len(words) == len(boxes), "You must provide words and boxes for an equal amount of examples"
+ if len(words) != len(boxes):
+ raise ValueError("You must provide words and boxes for an equal amount of examples")
for words_example, boxes_example in zip(words, boxes):
- assert len(words_example) == len(
- boxes_example
- ), "You must provide as many words as there are bounding boxes"
+ if len(words_example) != len(boxes_example):
+ raise ValueError("You must provide as many words as there are bounding boxes")
else:
- assert len(words) == len(boxes), "You must provide as many words as there are bounding boxes"
+ if len(words) != len(boxes):
+ raise ValueError("You must provide as many words as there are bounding boxes")
if is_batched:
if text_pair is not None and len(text) != len(text_pair):
raise ValueError(
- f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}."
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
)
batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
is_pair = bool(text_pair is not None)
@@ -1200,16 +1215,17 @@ def truncate_sequences(
)
if truncation_strategy == TruncationStrategy.ONLY_FIRST:
error_msg = (
- error_msg + "Please select another truncation strategy than "
+ error_msg
+ + "Please select another truncation strategy than "
f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
)
logger.error(error_msg)
elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
logger.warning(
- f"Be aware, overflowing tokens are not returned for the setting you have chosen,"
+ "Be aware, overflowing tokens are not returned for the setting you have chosen,"
f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
- f"truncation strategy. So the returned list will always be empty even if some "
- f"tokens have been removed."
+ "truncation strategy. So the returned list will always be empty even if some "
+ "tokens have been removed."
)
for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids):
@@ -1231,7 +1247,7 @@ def truncate_sequences(
f"We need to remove {num_tokens_to_remove} to truncate the input "
f"but the second sequence has a length {len(pair_ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
- f"for instance 'longest_first' or 'only_first'."
+ "for instance 'longest_first' or 'only_first'."
)
return (
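A quick illustration (not from the patch) of the behavioural consequence of swapping the `assert` statements for `ValueError`s: the checks still fire under `python -O`, and callers can catch them as ordinary exceptions.

```python
from transformers import LayoutLMv2Tokenizer

tokenizer = LayoutLMv2Tokenizer.from_pretrained("microsoft/layoutlmv2-base-uncased")

words = ["hello", "world"]
boxes = [[1, 2, 3, 4]]  # deliberately one box short

try:
    tokenizer(words, boxes=boxes)
except ValueError as err:
    print(err)  # You must provide as many words as there are bounding boxes
```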
diff --git a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py
index 2cc0de63add0..b61cf5ef7633 100644
--- a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py
+++ b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py
@@ -47,10 +47,14 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "microsoft/layoutlmv2-base-uncased": "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/vocab.txt",
+ "microsoft/layoutlmv2-base-uncased": (
+ "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "microsoft/layoutlmv2-base-uncased": "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/tokenizer.json",
+ "microsoft/layoutlmv2-base-uncased": (
+ "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/tokenizer.json"
+ ),
},
}
@@ -256,20 +260,23 @@ def _is_valid_text_input(t):
is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
words = text if text_pair is None else text_pair
- assert boxes is not None, "You must provide corresponding bounding boxes"
+ if boxes is None:
+ raise ValueError("You must provide corresponding bounding boxes")
if is_batched:
- assert len(words) == len(boxes), "You must provide words and boxes for an equal amount of examples"
+ if len(words) != len(boxes):
+ raise ValueError("You must provide words and boxes for an equal amount of examples")
for words_example, boxes_example in zip(words, boxes):
- assert len(words_example) == len(
- boxes_example
- ), "You must provide as many words as there are bounding boxes"
+ if len(words_example) != len(boxes_example):
+ raise ValueError("You must provide as many words as there are bounding boxes")
else:
- assert len(words) == len(boxes), "You must provide as many words as there are bounding boxes"
+ if len(words) != len(boxes):
+ raise ValueError("You must provide as many words as there are bounding boxes")
if is_batched:
if text_pair is not None and len(text) != len(text_pair):
raise ValueError(
- f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}."
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
)
batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
is_pair = bool(text_pair is not None)
diff --git a/src/transformers/models/layoutlmv3/__init__.py b/src/transformers/models/layoutlmv3/__init__.py
new file mode 100644
index 000000000000..cfa26057e87b
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/__init__.py
@@ -0,0 +1,115 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tokenizers_available,
+ is_torch_available,
+ is_vision_available,
+)
+
+
+_import_structure = {
+ "configuration_layoutlmv3": [
+ "LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "LayoutLMv3Config",
+ "LayoutLMv3OnnxConfig",
+ ],
+ "processing_layoutlmv3": ["LayoutLMv3Processor"],
+ "tokenization_layoutlmv3": ["LayoutLMv3Tokenizer"],
+}
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_layoutlmv3_fast"] = ["LayoutLMv3TokenizerFast"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_layoutlmv3"] = [
+ "LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "LayoutLMv3ForQuestionAnswering",
+ "LayoutLMv3ForSequenceClassification",
+ "LayoutLMv3ForTokenClassification",
+ "LayoutLMv3Model",
+ "LayoutLMv3PreTrainedModel",
+ ]
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_layoutlmv3"] = ["LayoutLMv3FeatureExtractor"]
+
+
+if TYPE_CHECKING:
+ from .configuration_layoutlmv3 import (
+ LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ LayoutLMv3Config,
+ LayoutLMv3OnnxConfig,
+ )
+ from .processing_layoutlmv3 import LayoutLMv3Processor
+ from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_layoutlmv3_fast import LayoutLMv3TokenizerFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_layoutlmv3 import (
+ LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST,
+ LayoutLMv3ForQuestionAnswering,
+ LayoutLMv3ForSequenceClassification,
+ LayoutLMv3ForTokenClassification,
+ LayoutLMv3Model,
+ LayoutLMv3PreTrainedModel,
+ )
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_layoutlmv3 import LayoutLMv3FeatureExtractor
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
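For orientation, a short sketch of what the new subpackage exposes once the optional backends (tokenizers, torch, vision) are installed; the random-weight instantiation is purely illustrative.

```python
from transformers.models.layoutlmv3 import (
    LayoutLMv3Config,
    LayoutLMv3FeatureExtractor,
    LayoutLMv3ForTokenClassification,
    LayoutLMv3Processor,
    LayoutLMv3Tokenizer,
    LayoutLMv3TokenizerFast,
)

# randomly initialised model, just to show the classes wire together
config = LayoutLMv3Config(num_labels=7)
model = LayoutLMv3ForTokenClassification(config)
print(config.model_type, config.num_labels)  # layoutlmv3 7
```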
diff --git a/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py
new file mode 100644
index 000000000000..d9ddde6289c9
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py
@@ -0,0 +1,294 @@
+# coding=utf-8
+# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" LayoutLMv3 model configuration"""
+
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Mapping, Optional
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...onnx.utils import compute_effective_axis_dimension
+from ...utils import logging
+
+
+if TYPE_CHECKING:
+ from ...processing_utils import ProcessorMixin
+ from ...utils import TensorType
+
+
+logger = logging.get_logger(__name__)
+
+LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/resolve/main/config.json",
+}
+
+
+class LayoutLMv3Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LayoutLMv3Model`]. It is used to instantiate an
+ LayoutLMv3 model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the LayoutLMv3
+ [microsoft/layoutlmv3-base](https://huggingface.co/microsoft/layoutlmv3-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50265):
+ Vocabulary size of the LayoutLMv3 model. Defines the number of different tokens that can be represented by
+ the `inputs_ids` passed when calling [`LayoutLMv3Model`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimension of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed when calling [`LayoutLMv3Model`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ max_2d_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum value that the 2D position embedding might ever be used with. Typically set this to something
+ large just in case (e.g., 1024).
+ coordinate_size (`int`, *optional*, defaults to `128`):
+ Dimension of the coordinate embeddings.
+ shape_size (`int`, *optional*, defaults to `128`):
+ Dimension of the width and height embeddings.
+ has_relative_attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not to use a relative attention bias in the self-attention mechanism.
+ rel_pos_bins (`int`, *optional*, defaults to 32):
+ The number of relative position bins to be used in the self-attention mechanism.
+ max_rel_pos (`int`, *optional*, defaults to 128):
+ The maximum number of relative positions to be used in the self-attention mechanism.
+ max_rel_2d_pos (`int`, *optional*, defaults to 256):
+ The maximum number of relative 2D positions in the self-attention mechanism.
+ rel_2d_pos_bins (`int`, *optional*, defaults to 64):
+ The number of 2D relative position bins in the self-attention mechanism.
+ has_spatial_attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not to use a spatial attention bias in the self-attention mechanism.
+ visual_embed (`bool`, *optional*, defaults to `True`):
+ Whether or not to add patch embeddings.
+ input_size (`int`, *optional*, defaults to `224`):
+ The size (resolution) of the images.
+ num_channels (`int`, *optional*, defaults to `3`):
+ The number of channels of the images.
+ patch_size (`int`, *optional*, defaults to `16`):
+ The size (resolution) of the patches.
+ classifier_dropout (`float`, *optional*):
+ The dropout ratio for the classification head.
+
+ Example:
+
+ ```python
+ >>> from transformers import LayoutLMv3Model, LayoutLMv3Config
+
+ >>> # Initializing a LayoutLMv3 microsoft/layoutlmv3-base style configuration
+ >>> configuration = LayoutLMv3Config()
+
+ >>> # Initializing a model from the microsoft/layoutlmv3-base style configuration
+ >>> model = LayoutLMv3Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "layoutlmv3"
+
+ def __init__(
+ self,
+ vocab_size=50265,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ max_2d_position_embeddings=1024,
+ coordinate_size=128,
+ shape_size=128,
+ has_relative_attention_bias=True,
+ rel_pos_bins=32,
+ max_rel_pos=128,
+ rel_2d_pos_bins=64,
+ max_rel_2d_pos=256,
+ has_spatial_attention_bias=True,
+ text_embed=True,
+ visual_embed=True,
+ input_size=224,
+ num_channels=3,
+ patch_size=16,
+ classifier_dropout=None,
+ **kwargs
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ hidden_size=hidden_size,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ intermediate_size=intermediate_size,
+ hidden_act=hidden_act,
+ hidden_dropout_prob=hidden_dropout_prob,
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
+ max_position_embeddings=max_position_embeddings,
+ type_vocab_size=type_vocab_size,
+ initializer_range=initializer_range,
+ layer_norm_eps=layer_norm_eps,
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+ self.max_2d_position_embeddings = max_2d_position_embeddings
+ self.coordinate_size = coordinate_size
+ self.shape_size = shape_size
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.rel_pos_bins = rel_pos_bins
+ self.max_rel_pos = max_rel_pos
+ self.has_spatial_attention_bias = has_spatial_attention_bias
+ self.rel_2d_pos_bins = rel_2d_pos_bins
+ self.max_rel_2d_pos = max_rel_2d_pos
+ self.text_embed = text_embed
+ self.visual_embed = visual_embed
+ self.input_size = input_size
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.classifier_dropout = classifier_dropout
+
+
+class LayoutLMv3OnnxConfig(OnnxConfig):
+
+ torch_onnx_minimum_version = version.parse("1.12")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ # The order of inputs is different for question answering and sequence classification
+ if self.task in ["question-answering", "sequence-classification"]:
+ return OrderedDict(
+ [
+ ("input_ids", {0: "batch", 1: "sequence"}),
+ ("attention_mask", {0: "batch", 1: "sequence"}),
+ ("bbox", {0: "batch", 1: "sequence"}),
+ ("pixel_values", {0: "batch", 1: "sequence"}),
+ ]
+ )
+ else:
+ return OrderedDict(
+ [
+ ("input_ids", {0: "batch", 1: "sequence"}),
+ ("bbox", {0: "batch", 1: "sequence"}),
+ ("attention_mask", {0: "batch", 1: "sequence"}),
+ ("pixel_values", {0: "batch", 1: "sequence"}),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-5
+
+ @property
+ def default_onnx_opset(self) -> int:
+ return 12
+
+ def generate_dummy_inputs(
+ self,
+ processor: "ProcessorMixin",
+ batch_size: int = -1,
+ seq_length: int = -1,
+ is_pair: bool = False,
+ framework: Optional["TensorType"] = None,
+ num_channels: int = 3,
+ image_width: int = 40,
+ image_height: int = 40,
+ ) -> Mapping[str, Any]:
+ """
+ Generate inputs to provide to the ONNX exporter for the specific framework
+
+ Args:
+ processor ([`ProcessorMixin`]):
+ The processor associated with this model configuration.
+ batch_size (`int`, *optional*, defaults to -1):
+ The batch size to export the model for (-1 means dynamic axis).
+ seq_length (`int`, *optional*, defaults to -1):
+ The sequence length to export the model for (-1 means dynamic axis).
+ is_pair (`bool`, *optional*, defaults to `False`):
+ Indicate if the input is a pair (sentence 1, sentence 2).
+ framework (`TensorType`, *optional*, defaults to `None`):
+ The framework (PyTorch or TensorFlow) that the processor will generate tensors for.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of channels of the generated images.
+ image_width (`int`, *optional*, defaults to 40):
+ The width of the generated images.
+ image_height (`int`, *optional*, defaults to 40):
+ The height of the generated images.
+
+ Returns:
+ Mapping[str, Any]: holding the kwargs to provide to the model's forward function
+ """
+
+ # A dummy image is used so OCR should not be applied
+ setattr(processor.feature_extractor, "apply_ocr", False)
+
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+ batch_size = compute_effective_axis_dimension(
+ batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
+ )
+ # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
+ token_to_add = processor.tokenizer.num_special_tokens_to_add(is_pair)
+ seq_length = compute_effective_axis_dimension(
+ seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
+ )
+ # Generate dummy inputs according to compute batch and sequence
+ dummy_text = [[" ".join([processor.tokenizer.unk_token]) * seq_length]] * batch_size
+
+ # Generate dummy bounding boxes
+ dummy_bboxes = [[[48, 84, 73, 128]]] * batch_size
+
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+ # batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
+ dummy_image = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
+
+ inputs = dict(
+ processor(
+ dummy_image,
+ text=dummy_text,
+ boxes=dummy_bboxes,
+ return_tensors=framework,
+ )
+ )
+
+ return inputs
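A hedged sketch of how the new `LayoutLMv3OnnxConfig` is expected to be driven; it assumes the `microsoft/layoutlmv3-base` checkpoint ships processor files and that torch is installed, so treat the checkpoint id and the resulting shapes as placeholders rather than guaranteed values.

```python
from transformers.models.layoutlmv3 import LayoutLMv3Config, LayoutLMv3OnnxConfig, LayoutLMv3Processor
from transformers.utils import TensorType

config = LayoutLMv3Config()
onnx_config = LayoutLMv3OnnxConfig(config, task="question-answering")
print(onnx_config.inputs)  # ordered mapping of input names to dynamic axes

processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")

# dynamic axes (-1) are replaced by fixed sizes of 2 samples / ~8 tokens so the
# ONNX exporter cannot constant-fold them away; apply_ocr is switched off internally
dummy_inputs = onnx_config.generate_dummy_inputs(processor, framework=TensorType.PYTORCH)
print({name: tensor.shape for name, tensor in dummy_inputs.items()})
```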
diff --git a/src/transformers/models/layoutlmv3/feature_extraction_layoutlmv3.py b/src/transformers/models/layoutlmv3/feature_extraction_layoutlmv3.py
new file mode 100644
index 000000000000..2d771a27903d
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/feature_extraction_layoutlmv3.py
@@ -0,0 +1,246 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Feature extractor class for LayoutLMv3.
+"""
+
+from typing import List, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageFeatureExtractionMixin, is_torch_tensor
+from ...utils import TensorType, is_pytesseract_available, logging, requires_backends
+
+
+# soft dependency
+if is_pytesseract_available():
+ import pytesseract
+
+logger = logging.get_logger(__name__)
+
+ImageInput = Union[
+ Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
+]
+
+
+def normalize_box(box, width, height):
+ return [
+ int(1000 * (box[0] / width)),
+ int(1000 * (box[1] / height)),
+ int(1000 * (box[2] / width)),
+ int(1000 * (box[3] / height)),
+ ]
+
+
+def apply_tesseract(image: Image.Image, lang: Optional[str], tesseract_config: Optional[str]):
+ """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
+ # apply OCR
+ data = pytesseract.image_to_data(image, lang=lang, output_type="dict", config=tesseract_config)
+ words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
+
+ # filter empty words and corresponding coordinates
+ irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
+ words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
+ left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
+ top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
+ width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
+ height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
+
+ # turn coordinates into (left, top, left+width, top+height) format
+ actual_boxes = []
+ for x, y, w, h in zip(left, top, width, height):
+ actual_box = [x, y, x + w, y + h]
+ actual_boxes.append(actual_box)
+
+ image_width, image_height = image.size
+
+ # finally, normalize the bounding boxes
+ normalized_boxes = []
+ for box in actual_boxes:
+ normalized_boxes.append(normalize_box(box, image_width, image_height))
+
+ assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
+
+ return words, normalized_boxes
+
+
+class LayoutLMv3FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
+ r"""
+ Constructs a LayoutLMv3 feature extractor. This can be used to resize + normalize document images, as well as to
+ apply OCR on them in order to get a list of words and normalized bounding boxes.
+
+ This feature extractor inherits from [`~feature_extraction_utils.PreTrainedFeatureExtractor`] which contains most
+ of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the input to a certain `size`.
+ size (`int` or `Tuple(int)`, *optional*, defaults to 224):
+ Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
+ integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
+ set to `True`.
+ resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
+ An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+ `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+ if `do_resize` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input with mean and standard deviation.
+ image_mean (`List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
+ The sequence of means for each channel, to be used when normalizing images.
+ image_std (`List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
+ The sequence of standard deviations for each channel, to be used when normalizing images.
+ apply_ocr (`bool`, *optional*, defaults to `True`):
+ Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
+ ocr_lang (`str`, *optional*):
+ The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
+ used.
+ tesseract_config (`str`, *optional*):
+ Any additional custom configuration flags that are forwarded to the `config` parameter when calling
+ Tesseract. For example: '--psm 6'.
+
+ <Tip>
+
+ LayoutLMv3FeatureExtractor uses Google's Tesseract OCR engine under the hood.
+
+ </Tip>"""
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize=True,
+ size=224,
+ resample=Image.BILINEAR,
+ do_normalize=True,
+ image_mean=None,
+ image_std=None,
+ apply_ocr=True,
+ ocr_lang=None,
+ tesseract_config="",
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+ self.apply_ocr = apply_ocr
+ self.ocr_lang = ocr_lang
+ self.tesseract_config = tesseract_config
+
+ def __call__(
+ self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several image(s).
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
+ number of channels, H and W are image height and width.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
+ width).
+ - **words** -- Optional words as identified by Tesseract OCR (only when [`LayoutLMv3FeatureExtractor`] was
+ initialized with `apply_ocr` set to `True`).
+ - **boxes** -- Optional bounding boxes as identified by Tesseract OCR, normalized based on the image size
+ (only when [`LayoutLMv3FeatureExtractor`] was initialized with `apply_ocr` set to `True`).
+
+ Examples:
+
+ ```python
+ >>> from transformers import LayoutLMv3FeatureExtractor
+ >>> from PIL import Image
+
+ >>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
+
+ >>> # option 1: with apply_ocr=True (default)
+ >>> feature_extractor = LayoutLMv3FeatureExtractor()
+ >>> encoding = feature_extractor(image, return_tensors="pt")
+ >>> print(encoding.keys())
+ >>> # dict_keys(['pixel_values', 'words', 'boxes'])
+
+ >>> # option 2: with apply_ocr=False
+ >>> feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
+ >>> encoding = feature_extractor(image, return_tensors="pt")
+ >>> print(encoding.keys())
+ >>> # dict_keys(['pixel_values'])
+ ```"""
+
+ # Input type checking for clearer error
+ valid_images = False
+
+ # Check that images has a valid type
+ if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
+ valid_images = True
+ elif isinstance(images, (list, tuple)):
+ if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
+ valid_images = True
+
+ if not valid_images:
+ raise ValueError(
+ "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
+ "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples), "
+ f"but is of type {type(images)}."
+ )
+
+ is_batched = bool(
+ isinstance(images, (list, tuple))
+ and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
+ )
+
+ if not is_batched:
+ images = [images]
+
+ # Tesseract OCR to get words + normalized bounding boxes
+ if self.apply_ocr:
+ requires_backends(self, "pytesseract")
+ words_batch = []
+ boxes_batch = []
+ for image in images:
+ words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang, self.tesseract_config)
+ words_batch.append(words)
+ boxes_batch.append(boxes)
+
+ # transformations (resizing + normalization)
+ if self.do_resize and self.size is not None:
+ images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
+ if self.do_normalize:
+ images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
+
+ # return as BatchFeature
+ data = {"pixel_values": images}
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ if self.apply_ocr:
+ encoded_inputs["words"] = words_batch
+ encoded_inputs["boxes"] = boxes_batch
+
+ return encoded_inputs
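Since both feature extractors rely on the same `normalize_box` helper, here is a self-contained illustration of the 0-1000 coordinate normalisation (the pixel values are made up):

```python
def normalize_box(box, width, height):
    # scale (x0, y0, x1, y1) pixel coordinates into the 0-1000 range the
    # LayoutLM family expects, independent of the page resolution
    return [
        int(1000 * (box[0] / width)),
        int(1000 * (box[1] / height)),
        int(1000 * (box[2] / width)),
        int(1000 * (box[3] / height)),
    ]

# a word box on a 1700x2200 pixel page
print(normalize_box([100, 300, 300, 350], width=1700, height=2200))
# [58, 136, 176, 159]
```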
diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
new file mode 100644
index 000000000000..f3bdd2cd8d90
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
@@ -0,0 +1,1311 @@
+# coding=utf-8
+# Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch LayoutLMv3 model."""
+
+import collections
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers import apply_chunking_to_forward
+from transformers.modeling_outputs import (
+ BaseModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+
+from ...activations import ACT2FN
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
+from .configuration_layoutlmv3 import LayoutLMv3Config
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LayoutLMv3Config"
+
+LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "microsoft/layoutlmv3-base",
+ "microsoft/layoutlmv3-large",
+ # See all LayoutLMv3 models at https://huggingface.co/models?filter=layoutlmv3
+]
+
+LAYOUTLMV3_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`LayoutLMv3Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+LAYOUTLMV3_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `{0}`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`LayoutLMv3Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
+ Bounding boxes of each input sequence tokens. Selected in the range `[0,
+ config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
+ format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
+ y1) represents the position of the lower right corner.
+
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Batch of document images.
+
+ attention_mask (`torch.FloatTensor` of shape `{0}`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`torch.LongTensor` of shape `{0}`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `{0}`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class LayoutLMv3PatchEmbeddings(nn.Module):
+ """LayoutLMv3 image (patch) embeddings. This class also automatically interpolates the position embeddings for varying
+ image sizes."""
+
+ def __init__(self, config):
+ super().__init__()
+
+ image_size = (
+ config.input_size
+ if isinstance(config.input_size, collections.abc.Iterable)
+ else (config.input_size, config.input_size)
+ )
+ patch_size = (
+ config.patch_size
+ if isinstance(config.patch_size, collections.abc.Iterable)
+ else (config.patch_size, config.patch_size)
+ )
+ self.patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+ self.proj = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values, position_embedding=None):
+ embeddings = self.proj(pixel_values)
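+ # (batch_size, num_channels, height, width) -> (batch_size, hidden_size, height // patch_size, width // patch_size)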
+
+ if position_embedding is not None:
+ # interpolate the position embedding to the corresponding size
+ position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1)
+ position_embedding = position_embedding.permute(0, 3, 1, 2)
+ patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
+ position_embedding = F.interpolate(position_embedding, size=(patch_height, patch_width), mode="bicubic")
+ embeddings = embeddings + position_embedding
+
+ embeddings = embeddings.flatten(2).transpose(1, 2)
+ return embeddings
+
+
+class LayoutLMv3TextEmbeddings(nn.Module):
+ """
+ LayoutLMv3 text embeddings. Same as `RobertaEmbeddings` but with added spatial (layout) embeddings.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+
+ self.padding_idx = config.pad_token_id
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+ )
+
+ self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
+ self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
+ self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
+ self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
+
+ def calculate_spatial_position_embeddings(self, bbox):
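+ # the six spatial embeddings (x0, y0, x1, y1, height, width) are concatenated along the last dimension,
+ # so 4 * coordinate_size + 2 * shape_size has to match hidden_size for the addition in forward() to work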
+ try:
+ left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
+ upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
+ right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
+ lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
+ except IndexError as e:
+ raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e
+
+ h_position_embeddings = self.h_position_embeddings(torch.clip(bbox[:, :, 3] - bbox[:, :, 1], 0, 1023))
+ w_position_embeddings = self.w_position_embeddings(torch.clip(bbox[:, :, 2] - bbox[:, :, 0], 0, 1023))
+
+ # unlike LayoutLMv1, which adds the spatial embeddings to the token embeddings, LayoutLMv2/v3 concatenate them (torch.cat)
+ spatial_position_embeddings = torch.cat(
+ [
+ left_position_embeddings,
+ upper_position_embeddings,
+ right_position_embeddings,
+ lower_position_embeddings,
+ h_position_embeddings,
+ w_position_embeddings,
+ ],
+ dim=-1,
+ )
+ return spatial_position_embeddings
+
+ def create_position_ids_from_input_ids(self, input_ids, padding_idx):
+ """
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
+ symbols are ignored. This is modified from fairseq's `utils.make_positions`.
+ """
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
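+ # e.g. with padding_idx=1 and input_ids=[[5, 6, 1, 1]]: mask=[[1, 1, 0, 0]], incremental_indices=[[1, 2, 0, 0]],
+ # and the returned position ids are [[2, 3, 1, 1]], so padding positions keep position id padding_idx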
+ mask = input_ids.ne(padding_idx).int()
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
+ return incremental_indices.long() + padding_idx
+
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+ """
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+ """
+ input_shape = inputs_embeds.size()[:-1]
+ sequence_length = input_shape[1]
+
+ position_ids = torch.arange(
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+ )
+ return position_ids.unsqueeze(0).expand(input_shape)
+
+ def forward(
+ self,
+ input_ids=None,
+ bbox=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ ):
+ if position_ids is None:
+ if input_ids is not None:
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
+ position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx).to(
+ input_ids.device
+ )
+ else:
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+
+ spatial_position_embeddings = self.calculate_spatial_position_embeddings(bbox)
+
+ embeddings = embeddings + spatial_position_embeddings
+
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class LayoutLMv3PreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = LayoutLMv3Config
+ base_model_prefix = "layoutlmv3"
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+class LayoutLMv3SelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.has_relative_attention_bias = config.has_relative_attention_bias
+ self.has_spatial_attention_bias = config.has_spatial_attention_bias
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def cogview_attention(self, attention_scores, alpha=32):
+ """
+ https://arxiv.org/abs/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation
+ (PB-Relax). A replacement for the original nn.Softmax(dim=-1)(attention_scores). The new attention_probs are
+ slightly slower to compute and may differ by a small numerical error; torch.allclose(standard_attention_probs,
+ cogview_attention_probs, atol=1e-08) can be used for comparison (the smaller the atol, the closer the match).
+ """
+ scaled_attention_scores = attention_scores / alpha
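+ # subtracting the row-wise maximum leaves the softmax output unchanged mathematically, but keeps all
+ # exponents <= 0, which avoids overflow in low-precision (e.g. fp16) training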
+ max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)
+ new_attention_scores = (scaled_attention_scores - max_value) * alpha
+ return nn.Softmax(dim=-1)(new_attention_scores)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ rel_pos=None,
+ rel_2d_pos=None,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ # The attention scores Q^T K / √d could be significantly larger than the input elements and result in overflow.
+ # Changing the computation order to Q^T (K / √d) alleviates the problem. (https://arxiv.org/pdf/2105.13290.pdf)
+ attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
+
+ if self.has_relative_attention_bias and self.has_spatial_attention_bias:
+ attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)
+ elif self.has_relative_attention_bias:
+ attention_scores += rel_pos / math.sqrt(self.attention_head_size)
+
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in the RobertaModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ # Use the trick from the CogView paper to stabilize training
+ attention_probs = self.cogview_attention(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput
+class LayoutLMv3SelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
+class LayoutLMv3Attention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.self = LayoutLMv3SelfAttention(config)
+ self.output = LayoutLMv3SelfOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ rel_pos=None,
+ rel_2d_pos=None,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions,
+ rel_pos=rel_pos,
+ rel_2d_pos=rel_2d_pos,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
+class LayoutLMv3Layer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = LayoutLMv3Attention(config)
+ self.intermediate = LayoutLMv3Intermediate(config)
+ self.output = LayoutLMv3Output(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ rel_pos=None,
+ rel_2d_pos=None,
+ ):
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ rel_pos=rel_pos,
+ rel_2d_pos=rel_2d_pos,
+ )
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class LayoutLMv3Encoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ self.has_relative_attention_bias = config.has_relative_attention_bias
+ self.has_spatial_attention_bias = config.has_spatial_attention_bias
+
+ if self.has_relative_attention_bias:
+ self.rel_pos_bins = config.rel_pos_bins
+ self.max_rel_pos = config.max_rel_pos
+ self.rel_pos_onehot_size = config.rel_pos_bins
+ self.rel_pos_bias = nn.Linear(self.rel_pos_onehot_size, config.num_attention_heads, bias=False)
+
+ if self.has_spatial_attention_bias:
+ self.max_rel_2d_pos = config.max_rel_2d_pos
+ self.rel_2d_pos_bins = config.rel_2d_pos_bins
+ self.rel_2d_pos_onehot_size = config.rel_2d_pos_bins
+ self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
+ self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
+
+ def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
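+ # T5-style bucketing: in bidirectional mode, half of the buckets encode the sign of the relative position;
+ # of the remaining buckets, half cover exact small distances and the rest cover logarithmically larger
+ # distances up to max_distance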
+ ret = 0
+ if bidirectional:
+ num_buckets //= 2
+ ret += (relative_position > 0).long() * num_buckets
+ n = torch.abs(relative_position)
+ else:
+ n = torch.max(-relative_position, torch.zeros_like(relative_position))
+ # now n is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ val_if_large = max_exact + (
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+ ).to(torch.long)
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
+
+ def _cal_1d_pos_emb(self, hidden_states, position_ids):
+ rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
+
+ rel_pos = self.relative_position_bucket(
+ rel_pos_mat,
+ num_buckets=self.rel_pos_bins,
+ max_distance=self.max_rel_pos,
+ )
+ rel_pos = F.one_hot(rel_pos, num_classes=self.rel_pos_onehot_size).type_as(hidden_states)
+ rel_pos = self.rel_pos_bias(rel_pos).permute(0, 3, 1, 2)
+ rel_pos = rel_pos.contiguous()
+ return rel_pos
+
+ def _cal_2d_pos_emb(self, hidden_states, bbox):
+ position_coord_x = bbox[:, :, 0]
+ position_coord_y = bbox[:, :, 3]
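+ # the horizontal bias is computed from the left edge (x0) and the vertical bias from the bottom edge (y1) of each box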
+ rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
+ rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)
+ rel_pos_x = self.relative_position_bucket(
+ rel_pos_x_2d_mat,
+ num_buckets=self.rel_2d_pos_bins,
+ max_distance=self.max_rel_2d_pos,
+ )
+ rel_pos_y = self.relative_position_bucket(
+ rel_pos_y_2d_mat,
+ num_buckets=self.rel_2d_pos_bins,
+ max_distance=self.max_rel_2d_pos,
+ )
+ rel_pos_x = F.one_hot(rel_pos_x, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states)
+ rel_pos_y = F.one_hot(rel_pos_y, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states)
+ rel_pos_x = self.rel_pos_x_bias(rel_pos_x).permute(0, 3, 1, 2)
+ rel_pos_y = self.rel_pos_y_bias(rel_pos_y).permute(0, 3, 1, 2)
+ rel_pos_x = rel_pos_x.contiguous()
+ rel_pos_y = rel_pos_y.contiguous()
+ rel_2d_pos = rel_pos_x + rel_pos_y
+ return rel_2d_pos
+
+ def forward(
+ self,
+ hidden_states,
+ bbox=None,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ position_ids=None,
+ patch_height=None,
+ patch_width=None,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ rel_pos = self._cal_1d_pos_emb(hidden_states, position_ids) if self.has_relative_attention_bias else None
+ rel_2d_pos = self._cal_2d_pos_emb(hidden_states, bbox) if self.has_spatial_attention_bias else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ # return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos)
+ # The above line will cause error:
+ # RuntimeError: Trying to backward through the graph a second time
+ # (or directly access saved tensors after they have already been freed).
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ output_attentions,
+ rel_pos,
+ rel_2d_pos,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ output_attentions,
+ rel_pos=rel_pos,
+ rel_2d_pos=rel_2d_pos,
+ )
+
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ all_hidden_states,
+ all_self_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate
+class LayoutLMv3Intermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaOutput
+class LayoutLMv3Output(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+@add_start_docstrings(
+ "The bare LayoutLMv3 Model transformer outputting raw hidden-states without any specific head on top.",
+ LAYOUTLMV3_START_DOCSTRING,
+)
+class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ if config.text_embed:
+ self.embeddings = LayoutLMv3TextEmbeddings(config)
+
+ if config.visual_embed:
+ # use the default pre-training parameters for fine-tuning (e.g., input_size)
+ # when the input_size is larger at fine-tuning time, the position embeddings are interpolated in the forward pass
+ self.patch_embed = LayoutLMv3PatchEmbeddings(config)
+
+ size = int(config.input_size / config.patch_size)
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.pos_embed = nn.Parameter(torch.zeros(1, size * size + 1, config.hidden_size))
+ self.pos_drop = nn.Dropout(p=0.0)
+
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
+ self.init_visual_bbox(image_size=(size, size))
+
+ self.norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
+
+ self.encoder = LayoutLMv3Encoder(config)
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
+ base class PreTrainedModel.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def init_visual_bbox(self, image_size=(14, 14), max_len=1000):
+ """
+ Create the bounding boxes for the visual (patch) tokens.
+ """
+ visual_bbox_x = torch.div(
+ torch.arange(0, max_len * (image_size[1] + 1), max_len), image_size[1], rounding_mode="trunc"
+ )
+ visual_bbox_y = torch.div(
+ torch.arange(0, max_len * (image_size[0] + 1), max_len), image_size[0], rounding_mode="trunc"
+ )
+ visual_bbox = torch.stack(
+ [
+ visual_bbox_x[:-1].repeat(image_size[0], 1),
+ visual_bbox_y[:-1].repeat(image_size[1], 1).transpose(0, 1),
+ visual_bbox_x[1:].repeat(image_size[0], 1),
+ visual_bbox_y[1:].repeat(image_size[1], 1).transpose(0, 1),
+ ],
+ dim=-1,
+ ).view(-1, 4)
+
+ cls_token_box = torch.tensor([[0 + 1, 0 + 1, max_len - 1, max_len - 1]])
+ self.visual_bbox = torch.cat([cls_token_box, visual_bbox], dim=0)
+
+ def calculate_visual_bbox(self, device, dtype, batch_size):
+ visual_bbox = self.visual_bbox.repeat(batch_size, 1, 1)
+ visual_bbox = visual_bbox.to(device).type(dtype)
+ return visual_bbox
+
+ def forward_image(self, pixel_values):
+ embeddings = self.patch_embed(pixel_values)
+
+ # add [CLS] token
+ batch_size, seq_len, _ = embeddings.size()
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
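+ # (batch_size, num_patches + 1, hidden_size)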
+
+ # add position embeddings
+ if self.pos_embed is not None:
+ embeddings = embeddings + self.pos_embed
+
+ embeddings = self.pos_drop(embeddings)
+ embeddings = self.norm(embeddings)
+
+ return embeddings
+
+ @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ bbox=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
+ >>> model = AutoModel.from_pretrained("microsoft/layoutlmv3-base")
+
+ >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
+ >>> example = dataset[0]
+ >>> image = example["image"]
+ >>> words = example["tokens"]
+ >>> boxes = example["bboxes"]
+
+ >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
+
+ >>> outputs = model(**encoding)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ device = input_ids.device
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = inputs_embeds.device
+ elif pixel_values is not None:
+ batch_size = len(pixel_values)
+ device = pixel_values.device
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds or pixel_values")
+
+ if input_ids is not None or inputs_embeds is not None:
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length), device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+ if bbox is None:
+ bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ bbox=bbox,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ )
+
+ final_bbox = final_position_ids = None
+ patch_height = patch_width = None
+ if pixel_values is not None:
+ patch_height, patch_width = int(pixel_values.shape[2] / self.config.patch_size), int(
+ pixel_values.shape[3] / self.config.patch_size
+ )
+ visual_embeddings = self.forward_image(pixel_values)
+ visual_attention_mask = torch.ones(
+ (batch_size, visual_embeddings.shape[1]), dtype=torch.long, device=device
+ )
+ if attention_mask is not None:
+ attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1)
+ else:
+ attention_mask = visual_attention_mask
+
+ if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
+ if self.config.has_spatial_attention_bias:
+ visual_bbox = self.calculate_visual_bbox(device, dtype=torch.long, batch_size=batch_size)
+ if bbox is not None:
+ final_bbox = torch.cat([bbox, visual_bbox], dim=1)
+ else:
+ final_bbox = visual_bbox
+
+ visual_position_ids = torch.arange(
+ 0, visual_embeddings.shape[1], dtype=torch.long, device=device
+ ).repeat(batch_size, 1)
+ if input_ids is not None or inputs_embeds is not None:
+ position_ids = torch.arange(0, input_shape[1], device=device).unsqueeze(0)
+ position_ids = position_ids.expand(input_shape)
+ final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1)
+ else:
+ final_position_ids = visual_position_ids
+
+ if input_ids is not None or inputs_embeds is not None:
+ embedding_output = torch.cat([embedding_output, visual_embeddings], dim=1)
+ else:
+ embedding_output = visual_embeddings
+
+ embedding_output = self.LayerNorm(embedding_output)
+ embedding_output = self.dropout(embedding_output)
+ elif self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
+ if self.config.has_spatial_attention_bias:
+ final_bbox = bbox
+ if self.config.has_relative_attention_bias:
+ position_ids = self.embeddings.position_ids[:, : input_shape[1]]
+ position_ids = position_ids.expand_as(input_ids)
+ final_position_ids = position_ids
+
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+ attention_mask, None, device, dtype=embedding_output.dtype
+ )
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ bbox=final_bbox,
+ position_ids=final_position_ids,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ patch_height=patch_height,
+ patch_width=patch_width,
+ )
+
+ sequence_output = encoder_outputs[0]
+
+ if not return_dict:
+ return (sequence_output,) + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class LayoutLMv3ClassificationHead(nn.Module):
+ """
+ Head for sentence-level classification tasks. Reference: RobertaClassificationHead
+ """
+
+ def __init__(self, config, pool_feature=False):
+ super().__init__()
+ self.pool_feature = pool_feature
+ if pool_feature:
+ self.dense = nn.Linear(config.hidden_size * 3, config.hidden_size)
+ else:
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+ def forward(self, x):
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = torch.tanh(x)
+ x = self.dropout(x)
+ x = self.out_proj(x)
+ return x
+
+
+@add_start_docstrings(
+ """
+ LayoutLMv3 Model with a token classification head on top (a linear layer on top of the final hidden states) e.g.
+ for sequence labeling (information extraction) tasks such as [FUNSD](https://guillaumejaume.github.io/FUNSD/),
+ [SROIE](https://rrc.cvc.uab.es/?ch=13), [CORD](https://github.com/clovaai/cord) and
+ [Kleister-NDA](https://github.com/applicaai/kleister-nda).
+ """,
+ LAYOUTLMV3_START_DOCSTRING,
+)
+class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.layoutlmv3 = LayoutLMv3Model(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
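+ # for small label sets a plain linear layer is used as classifier; for larger label sets the heavier
+ # RoBERTa-style classification head defined above (dense + tanh + dropout + projection) is used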
+ if config.num_labels < 10:
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ else:
+ self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ bbox=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ pixel_values=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoProcessor, AutoModelForTokenClassification
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
+ >>> model = AutoModelForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=7)
+
+ >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
+ >>> example = dataset[0]
+ >>> image = example["image"]
+ >>> words = example["tokens"]
+ >>> boxes = example["bboxes"]
+ >>> word_labels = example["ner_tags"]
+
+ >>> encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
+
+ >>> outputs = model(**encoding)
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.layoutlmv3(
+ input_ids,
+ bbox=bbox,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ pixel_values=pixel_values,
+ )
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+ # only take the text part of the output representations
+ sequence_output = outputs[0][:, :seq_length]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ LayoutLMv3 Model with a span classification head on top for extractive question-answering tasks such as
+ [DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the text part of the hidden-states output to
+ compute `span start logits` and `span end logits`).
+ """,
+ LAYOUTLMV3_START_DOCSTRING,
+)
+class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.layoutlmv3 = LayoutLMv3Model(config)
+ self.qa_outputs = LayoutLMv3ClassificationHead(config, pool_feature=False)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ start_positions=None,
+ end_positions=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ bbox=None,
+ pixel_values=None,
+ ):
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoProcessor, AutoModelForQuestionAnswering
+ >>> from datasets import load_dataset
+ >>> import torch
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
+ >>> model = AutoModelForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")
+
+ >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
+ >>> example = dataset[0]
+ >>> image = example["image"]
+ >>> question = "what's his name?"
+ >>> words = example["tokens"]
+ >>> boxes = example["bboxes"]
+
+ >>> encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")
+ >>> start_positions = torch.tensor([1])
+ >>> end_positions = torch.tensor([3])
+
+ >>> outputs = model(**encoding, start_positions=start_positions, end_positions=end_positions)
+ >>> loss = outputs.loss
+ >>> start_scores = outputs.start_logits
+ >>> end_scores = outputs.end_logits
+ ```"""
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.layoutlmv3(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ bbox=bbox,
+ pixel_values=pixel_values,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split adds a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ LayoutLMv3 Model with a sequence classification head on top (a linear layer on top of the final hidden state of the
+ [CLS] token) e.g. for document image classification tasks such as the
+ [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset.
+ """,
+ LAYOUTLMV3_START_DOCSTRING,
+)
+class LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+ self.layoutlmv3 = LayoutLMv3Model(config)
+ self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ bbox=None,
+ pixel_values=None,
+ ):
+ """
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoProcessor, AutoModelForSequenceClassification
+ >>> from datasets import load_dataset
+ >>> import torch
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
+ >>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
+
+ >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
+ >>> example = dataset[0]
+ >>> image = example["image"]
+ >>> words = example["tokens"]
+ >>> boxes = example["bboxes"]
+
+ >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
+ >>> sequence_label = torch.tensor([1])
+
+ >>> outputs = model(**encoding, labels=sequence_label)
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.layoutlmv3(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ bbox=bbox,
+ pixel_values=pixel_values,
+ )
+
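+ # take the final hidden state of the first ([CLS]) token as the sequence representation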
+ sequence_output = outputs[0][:, 0, :]
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/src/transformers/models/layoutlmv3/processing_layoutlmv3.py b/src/transformers/models/layoutlmv3/processing_layoutlmv3.py
new file mode 100644
index 000000000000..c80b2bd5f203
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/processing_layoutlmv3.py
@@ -0,0 +1,158 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for LayoutLMv3.
+"""
+from typing import List, Optional, Union
+
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from ...utils import TensorType
+
+
+class LayoutLMv3Processor(ProcessorMixin):
+ r"""
+ Constructs a LayoutLMv3 processor which combines a LayoutLMv3 feature extractor and a LayoutLMv3 tokenizer into a
+ single processor.
+
+ [`LayoutLMv3Processor`] offers all the functionalities you need to prepare data for the model.
+
+ It first uses [`LayoutLMv3FeatureExtractor`] to resize and normalize document images, and optionally applies OCR to
+ get words and normalized bounding boxes. These are then provided to [`LayoutLMv3Tokenizer`] or
+ [`LayoutLMv3TokenizerFast`], which turns the words and bounding boxes into token-level `input_ids`,
+ `attention_mask`, `token_type_ids`, `bbox`. Optionally, one can provide integer `word_labels`, which are turned
+ into token-level `labels` for token classification tasks (such as FUNSD, CORD).
+
+ Args:
+ feature_extractor (`LayoutLMv3FeatureExtractor`):
+ An instance of [`LayoutLMv3FeatureExtractor`]. The feature extractor is a required input.
+ tokenizer (`LayoutLMv3Tokenizer` or `LayoutLMv3TokenizerFast`):
+ An instance of [`LayoutLMv3Tokenizer`] or [`LayoutLMv3TokenizerFast`]. The tokenizer is a required input.
+ """
+ feature_extractor_class = "LayoutLMv3FeatureExtractor"
+ tokenizer_class = ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast")
+
+ def __call__(
+ self,
+ images,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+ text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
+ boxes: Union[List[List[int]], List[List[List[int]]]] = None,
+ word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ This method first forwards the `images` argument to [`~LayoutLMv3FeatureExtractor.__call__`]. In case
+ [`LayoutLMv3FeatureExtractor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and
+ bounding boxes along with the additional arguments to [`~LayoutLMv3Tokenizer.__call__`] and returns the output,
+ together with resized and normalized `pixel_values`. In case [`LayoutLMv3FeatureExtractor`] was initialized
+ with `apply_ocr` set to `False`, it passes the words (`text`/`text_pair`) and `boxes` specified by the user
+ along with the additional arguments to [`~LayoutLMv3Tokenizer.__call__`] and returns the output, together with
+ resized and normalized `pixel_values`.
+
+ Please refer to the docstring of the above two methods for more information.
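+
+ Example (an illustrative sketch that mirrors the model docstrings in this PR, assuming a feature extractor
+ initialized with `apply_ocr=False` and user-provided words and boxes):
+
+ ```python
+ >>> from transformers import LayoutLMv3Processor
+ >>> from datasets import load_dataset
+
+ >>> processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
+
+ >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
+ >>> example = dataset[0]
+ >>> encoding = processor(example["image"], example["tokens"], boxes=example["bboxes"], return_tensors="pt")
+ >>> # encoding now contains input_ids, attention_mask, bbox and pixel_values
+ ```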
+ """
+ # verify input
+ if self.feature_extractor.apply_ocr and (boxes is not None):
+ raise ValueError(
+ "You cannot provide bounding boxes "
+ "if you initialized the feature extractor with apply_ocr set to True."
+ )
+
+ if self.feature_extractor.apply_ocr and (word_labels is not None):
+ raise ValueError(
+ "You cannot provide word labels if you initialized the feature extractor with apply_ocr set to True."
+ )
+
+ # first, apply the feature extractor
+ features = self.feature_extractor(images=images, return_tensors=return_tensors)
+
+ # second, apply the tokenizer
+ if text is not None and self.feature_extractor.apply_ocr and text_pair is None:
+ if isinstance(text, str):
+ text = [text] # add batch dimension (as the feature extractor always adds a batch dimension)
+ text_pair = features["words"]
+
+ encoded_inputs = self.tokenizer(
+ text=text if text is not None else features["words"],
+ text_pair=text_pair if text_pair is not None else None,
+ boxes=boxes if boxes is not None else features["boxes"],
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ return_tensors=return_tensors,
+ **kwargs,
+ )
+
+ # add pixel values
+ images = features.pop("pixel_values")
+ if return_overflowing_tokens is True:
+ images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"])
+ encoded_inputs["pixel_values"] = images
+
+ return encoded_inputs
+
+ def get_overflowing_images(self, images, overflow_to_sample_mapping):
+ # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image
+ images_with_overflow = []
+ for sample_idx in overflow_to_sample_mapping:
+ images_with_overflow.append(images[sample_idx])
+
+ if len(images_with_overflow) != len(overflow_to_sample_mapping):
+ raise ValueError(
+ "Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got"
+ f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}"
+ )
+
+ return images_with_overflow
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
+ to the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
diff --git a/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py
new file mode 100644
index 000000000000..b01e70ffb037
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py
@@ -0,0 +1,1478 @@
+# coding=utf-8
+# Copyright The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for LayoutLMv3. Same as LayoutLMv2, but RoBERTa-like BPE tokenization instead of WordPiece."""
+
+import json
+import os
+from functools import lru_cache
+from typing import Dict, List, Optional, Tuple, Union
+
+import regex as re
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...tokenization_utils_base import (
+ BatchEncoding,
+ EncodedInput,
+ PreTokenizedInput,
+ TextInput,
+ TextInputPair,
+ TruncationStrategy,
+)
+from ...utils import PaddingStrategy, TensorType, add_end_docstrings, logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+ "vocab_file": "vocab.json",
+ "merges_file": "merges.txt",
+}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/raw/main/vocab.json",
+ "microsoft/layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/raw/main/vocab.json",
+ },
+ "merges_file": {
+ "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/raw/main/merges.txt",
+ "microsoft/layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/raw/main/merges.txt",
+ },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "microsoft/layoutlmv3-base": 512,
+ "microsoft/layoutlmv3-large": 512,
+}
+
+
+LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING = r"""
+ add_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to encode the sequences with the special tokens relative to their model.
+ padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+ to the maximum acceptable input length for the model if that argument is not provided. This will
+ truncate token by token, removing a token from the longest sequence in the pair if a pair of
+ sequences (or a batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters.
+
+ If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+ is required by one of the truncation/padding parameters. If the model has no specific maximum input
+ length (like XLNet) truncation/padding to a maximum length will be deactivated.
+ stride (`int`, *optional*, defaults to 0):
+ If set to a number along with `max_length`, the overflowing tokens returned when
+ `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
+ returned to provide some overlap between truncated and overflowing sequences. The value of this
+ argument defines the number of overlapping tokens.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+ the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+"""
+
+
+LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
+ add_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to encode the sequences with the special tokens relative to their model.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+ to the maximum acceptable input length for the model if that argument is not provided. This will
+ truncate token by token, removing a token from the longest sequence in the pair if a pair of
+ sequences (or a batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to
+ `None`, this will use the predefined model maximum length if a maximum length is required by one of the
+ truncation/padding parameters. If the model has no specific maximum input length (like XLNet)
+ truncation/padding to a maximum length will be deactivated.
+ stride (`int`, *optional*, defaults to 0):
+ If set to a number along with `max_length`, the overflowing tokens returned when
+ `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
+ returned to provide some overlap between truncated and overflowing sequences. The value of this
+ argument defines the number of overlapping tokens.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+ the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+"""
+
+
+@lru_cache()
+# Copied from transformers.models.roberta.tokenization_roberta.bytes_to_unicode
+def bytes_to_unicode():
+ """
+ Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to
+ whitespace/control characters that the bpe code barfs on.
+
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+ tables between utf-8 bytes and unicode strings.
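+
+ For example (the values follow from the byte map built below): `bytes_to_unicode()[ord("A")]` is simply `"A"`,
+ while the space byte (32) is shifted above 255 and maps to `"Ġ"`, which is why a leading space shows up as
+ `"Ġ"` at the start of BPE tokens.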
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+# Copied from transformers.models.roberta.tokenization_roberta.get_pairs
+def get_pairs(word):
+ """
+ Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length strings).
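+
+ For example, `get_pairs(("h", "e", "l", "l", "o"))` returns `{("h", "e"), ("e", "l"), ("l", "l"), ("l", "o")}`.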
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+class LayoutLMv3Tokenizer(PreTrainedTokenizer):
+ r"""
+ Construct a LayoutLMv3 tokenizer. Based on [`RobertaTokenizer`] (Byte-Pair Encoding or BPE).
+ [`LayoutLMv3Tokenizer`] can be used to turn words, word-level bounding boxes and optional word labels to
+ token-level `input_ids`, `attention_mask`, `token_type_ids`, `bbox`, and optional `labels` (for token
+ classification).
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ [`LayoutLMv3Tokenizer`] runs end-to-end tokenization: punctuation splitting and byte-pair encoding. It also turns the
+ word-level bounding boxes into token-level bounding boxes.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+ </Tip>
+
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ </Tip>
+
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows to treat the leading word just as any
+ other word. (The RoBERTa tokenizer detects the beginning of words by the preceding space.)
+ cls_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [CLS] token.
+ sep_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [SEP] token.
+ pad_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [PAD] token.
+ pad_token_label (`int`, *optional*, defaults to -100):
+ The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's
+ CrossEntropyLoss.
+ only_label_first_subword (`bool`, *optional*, defaults to `True`):
+ Whether or not to only label the first subword, in case word labels are provided.
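+
+ Example (illustrative; the words, boxes and labels below are made-up placeholders, and the checkpoint name is
+ one of the pretrained models listed above):
+
+ ```python
+ from transformers import LayoutLMv3Tokenizer
+
+ tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
+ words = ["hello", "world"]
+ boxes = [[637, 773, 693, 782], [698, 773, 733, 782]]  # word-level boxes on a 0-1000 scale
+ word_labels = [0, 1]
+ encoding = tokenizer(words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
+ # `encoding` holds token-level "input_ids", "attention_mask", "bbox" and "labels"
+ ```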
+ """
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["input_ids", "attention_mask", "bbox"]
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ errors="replace",
+ bos_token="<s>",
+ eos_token="</s>",
+ sep_token="</s>",
+ cls_token="<s>",
+ unk_token="<unk>",
+ pad_token="<pad>",
+ mask_token="<mask>",
+ add_prefix_space=True,
+ cls_token_box=[0, 0, 0, 0],
+ sep_token_box=[0, 0, 0, 0],
+ pad_token_box=[0, 0, 0, 0],
+ pad_token_label=-100,
+ only_label_first_subword=True,
+ **kwargs
+ ):
+ bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+ eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+ sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
+ cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
+ unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+
+ # Mask token behave like a normal word, i.e. include the space before it
+ mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+ super().__init__(
+ errors=errors,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ add_prefix_space=add_prefix_space,
+ cls_token_box=cls_token_box,
+ sep_token_box=sep_token_box,
+ pad_token_box=pad_token_box,
+ pad_token_label=pad_token_label,
+ only_label_first_subword=only_label_first_subword,
+ **kwargs,
+ )
+
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
+ self.encoder = json.load(vocab_handle)
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ bpe_merges = merges_handle.read().split("\n")[1:-1]
+ bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ self.cache = {}
+ self.add_prefix_space = add_prefix_space
+
+ # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+
+ # additional properties
+ self.cls_token_box = cls_token_box
+ self.sep_token_box = sep_token_box
+ self.pad_token_box = pad_token_box
+ self.pad_token_label = pad_token_label
+ self.only_label_first_subword = only_label_first_subword
+
+ @property
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.vocab_size
+ def vocab_size(self):
+ return len(self.encoder)
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.get_vocab
+ def get_vocab(self):
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.bpe
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j])
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._tokenize
+ def _tokenize(self, text):
+ """Tokenize a string."""
+ bpe_tokens = []
+ for token in re.findall(self.pat, text):
+ token = "".join(
+ self.byte_encoder[b] for b in token.encode("utf-8")
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+ return bpe_tokens
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._convert_token_to_id
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._convert_id_to_token
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.decoder.get(index)
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.convert_tokens_to_string
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ text = "".join(tokens)
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+ return text
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.save_vocabulary
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ merge_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+ )
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write("#version: 0.2\n")
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!"
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ return vocab_file, merge_file
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.build_inputs_with_special_tokens
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A RoBERTa sequence has the following format:
+
+ - single sequence: `<s> X </s>`
+ - pair of sequences: `<s> A </s></s> B </s>`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
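+
+ For example (with illustrative ids `cls_token_id=0`, `sep_token_id=2`): `build_inputs_with_special_tokens([7, 8])`
+ returns `[0, 7, 8, 2]`, and `build_inputs_with_special_tokens([7, 8], [9])` returns `[0, 7, 8, 2, 2, 9, 2]`.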
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.get_special_tokens_mask
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
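+
+ For example, `get_special_tokens_mask([7, 8], [9])` returns `[1, 0, 0, 1, 1, 0, 1]`.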
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.create_token_type_ids_from_sequences
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not
+ make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of zeros.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+ add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
+ # If the text starts with a token that should not be split, no space is added before the text in any case.
+ # It's necessary to match the fast tokenization
+ if (
+ (is_split_into_words or add_prefix_space)
+ and (len(text) > 0 and not text[0].isspace())
+ and sum([text.startswith(no_split_token) for no_split_token in self.unique_no_split_tokens]) == 0
+ ):
+ text = " " + text
+ return (text, kwargs)
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.__call__
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+ text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
+ boxes: Union[List[List[int]], List[List[List[int]]]] = None,
+ word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+ sequences with word-level normalized bounding boxes and optional labels.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
+ (words of a single example or questions of a batch of examples) or a list of list of strings (batch of
+ words).
+ text_pair (`List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
+ (pretokenized string).
+ boxes (`List[List[int]]`, `List[List[List[int]]]`):
+ Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
+ word_labels (`List[int]`, `List[List[int]]`, *optional*):
+ Word-level integer labels (for token classification tasks such as FUNSD, CORD).
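+
+ Example of the pair usage for document visual question answering (illustrative; the question, words and boxes
+ below are made-up placeholders):
+
+ ```python
+ from transformers import LayoutLMv3Tokenizer
+
+ tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
+ question = "what is the total?"
+ words = ["Total:", "8.50"]
+ boxes = [[100, 200, 180, 220], [190, 200, 230, 220]]  # one box per word, on a 0-1000 scale
+ encoding = tokenizer(question, words, boxes=boxes, return_tensors="pt")
+ # question tokens are assigned `pad_token_box`; the document words keep their own boxes
+ ```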
+ """
+ # Input type checking for clearer error
+ def _is_valid_text_input(t):
+ if isinstance(t, str):
+ # Strings are fine
+ return True
+ elif isinstance(t, (list, tuple)):
+ # List are fine as long as they are...
+ if len(t) == 0:
+ # ... empty
+ return True
+ elif isinstance(t[0], str):
+ # ... list of strings
+ return True
+ elif isinstance(t[0], (list, tuple)):
+ # ... list with an empty list or with a list of strings
+ return len(t[0]) == 0 or isinstance(t[0][0], str)
+ else:
+ return False
+ else:
+ return False
+
+ if text_pair is not None:
+ # in case text + text_pair are provided, text = questions, text_pair = words
+ if not _is_valid_text_input(text):
+ raise ValueError("text input must be of type `str` (single example) or `List[str]` (batch of examples).")
+ if not isinstance(text_pair, (list, tuple)):
+ raise ValueError(
+ "Words must be of type `List[str]` (single pretokenized example), "
+ "or `List[List[str]]` (batch of pretokenized examples)."
+ )
+ else:
+ # in case only text is provided => must be words
+ if not isinstance(text, (list, tuple)):
+ raise ValueError(
+ "Words must be of type `List[str]` (single pretokenized example), "
+ "or `List[List[str]]` (batch of pretokenized examples)."
+ )
+
+ if text_pair is not None:
+ is_batched = isinstance(text, (list, tuple))
+ else:
+ is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
+
+ words = text if text_pair is None else text_pair
+ if boxes is None:
+ raise ValueError("You must provide corresponding bounding boxes")
+ if is_batched:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide words and boxes for an equal amount of examples")
+ for words_example, boxes_example in zip(words, boxes):
+ if len(words_example) != len(boxes_example):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+ else:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+
+ if is_batched:
+ if text_pair is not None and len(text) != len(text_pair):
+ raise ValueError(
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
+ )
+ batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
+ is_pair = bool(text_pair is not None)
+ return self.batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+ else:
+ return self.encode_plus(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.batch_encode_plus
+ def batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ List[TextInput],
+ List[TextInputPair],
+ List[PreTokenizedInput],
+ ],
+ is_pair: bool = None,
+ boxes: Optional[List[List[List[int]]]] = None,
+ word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer._batch_encode_plus
+ def _batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ List[TextInput],
+ List[TextInputPair],
+ List[PreTokenizedInput],
+ ],
+ is_pair: bool = None,
+ boxes: Optional[List[List[List[int]]]] = None,
+ word_labels: Optional[List[List[int]]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+
+ if return_offsets_mapping:
+ raise NotImplementedError(
+ "return_offset_mapping is not available when using Python tokenizers. "
+ "To use this feature, change your tokenizer to one deriving from "
+ "transformers.PreTrainedTokenizerFast."
+ )
+
+ batch_outputs = self._batch_prepare_for_model(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ return_tensors=return_tensors,
+ verbose=verbose,
+ )
+
+ return BatchEncoding(batch_outputs)
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer._batch_prepare_for_model
+ def _batch_prepare_for_model(
+ self,
+ batch_text_or_text_pairs,
+ is_pair: bool = None,
+ boxes: Optional[List[List[int]]] = None,
+ word_labels: Optional[List[List[int]]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[str] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ ) -> BatchEncoding:
+ """
+ Prepares a sequence of input ids, or a pair of sequences of input ids so that it can be used by the model. It
+ adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
+ manages a moving window (with user defined stride) for overflowing tokens.
+
+ Args:
+ batch_text_or_text_pairs: list of examples, each either a list of words or a (question, words) pair
+ """
+
+ batch_outputs = {}
+ for idx, example in enumerate(zip(batch_text_or_text_pairs, boxes)):
+ batch_text_or_text_pair, boxes_example = example
+ outputs = self.prepare_for_model(
+ batch_text_or_text_pair[0] if is_pair else batch_text_or_text_pair,
+ batch_text_or_text_pair[1] if is_pair else None,
+ boxes_example,
+ word_labels=word_labels[idx] if word_labels is not None else None,
+ add_special_tokens=add_special_tokens,
+ padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward
+ truncation=truncation_strategy.value,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=None, # we pad in batch afterward
+ return_attention_mask=False, # we pad in batch afterward
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ return_tensors=None, # We convert the whole batch to tensors at the end
+ prepend_batch_axis=False,
+ verbose=verbose,
+ )
+
+ for key, value in outputs.items():
+ if key not in batch_outputs:
+ batch_outputs[key] = []
+ batch_outputs[key].append(value)
+
+ batch_outputs = self.pad(
+ batch_outputs,
+ padding=padding_strategy.value,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ )
+
+ batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
+ return batch_outputs
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.encode
+ def encode(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[List[List[int]]] = None,
+ word_labels: Optional[List[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> List[int]:
+ encoded_inputs = self.encode_plus(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return encoded_inputs["input_ids"]
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.encode_plus
+ def encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[List[List[int]]] = None,
+ word_labels: Optional[List[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ Tokenize and prepare for the model a sequence or a pair of sequences.
+
+ <Tip warning={true}>
+
+ This method is deprecated, `__call__` should be used instead.
+
+ </Tip>
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
+ text_pair (`List[str]` or `List[int]`, *optional*):
+ Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
+ list of list of strings (words of a batch of examples).
+ """
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._encode_plus(
+ text=text,
+ boxes=boxes,
+ text_pair=text_pair,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer._encode_plus
+ def _encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[List[List[int]]] = None,
+ word_labels: Optional[List[int]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ if return_offsets_mapping:
+ raise NotImplementedError(
+ "return_offset_mapping is not available when using Python tokenizers. "
+ "To use this feature, change your tokenizer to one deriving from "
+ "transformers.PreTrainedTokenizerFast. "
+ "More information on available tokenizers at "
+ "https://github.com/huggingface/transformers/pull/2674"
+ )
+
+ return self.prepare_for_model(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding_strategy.value,
+ truncation=truncation_strategy.value,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ prepend_batch_axis=True,
+ return_attention_mask=return_attention_mask,
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ verbose=verbose,
+ )
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def prepare_for_model(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[List[List[int]]] = None,
+ word_labels: Optional[List[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ prepend_batch_axis: bool = False,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ Prepares a sequence or a pair of sequences so that it can be used by the model. It adds special tokens,
+ truncates sequences if overflowing while taking into account the special tokens and manages a moving window
+ (with user defined stride) for overflowing tokens. Please note that for a *text_pair* different from `None` and
+ *truncation_strategy* set to `longest_first` or `True`, it is not possible to return overflowing tokens. Such a
+ combination of arguments will raise an error.
+
+ Word-level `boxes` are turned into token-level `bbox`. If provided, word-level `word_labels` are turned into
+ token-level `labels`. The word label is used for the first token of the word, while remaining tokens are
+ labeled with -100, such that they will be ignored by the loss function.
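+ As an illustration (hypothetical subword split): if a word were tokenized into two subword tokens, its bounding
+ box would appear twice in `bbox`, and with `only_label_first_subword=True` its label would be encoded as
+ `[label, -100]`.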
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
+ text_pair (`List[str]` or `List[int]`, *optional*):
+ Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
+ list of list of strings (words of a batch of examples).
+ """
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ tokens = []
+ pair_tokens = []
+ token_boxes = []
+ pair_token_boxes = []
+ labels = []
+
+ if text_pair is None:
+ if word_labels is None:
+ # CASE 1: document image classification (training + inference) + CASE 2: token classification (inference)
+ for word, box in zip(text, boxes):
+ if len(word) < 1: # skip empty words
+ continue
+ word_tokens = self.tokenize(word)
+ tokens.extend(word_tokens)
+ token_boxes.extend([box] * len(word_tokens))
+ else:
+ # CASE 2: token classification (training)
+ for word, box, label in zip(text, boxes, word_labels):
+ if len(word) < 1: # skip empty words
+ continue
+ word_tokens = self.tokenize(word)
+ tokens.extend(word_tokens)
+ token_boxes.extend([box] * len(word_tokens))
+ if self.only_label_first_subword:
+ # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+ labels.extend([label] + [self.pad_token_label] * (len(word_tokens) - 1))
+ else:
+ labels.extend([label] * len(word_tokens))
+ else:
+ # CASE 3: document visual question answering (inference)
+ # text = question
+ # text_pair = words
+ tokens = self.tokenize(text)
+ token_boxes = [self.pad_token_box for _ in range(len(tokens))]
+
+ for word, box in zip(text_pair, boxes):
+ if len(word) < 1: # skip empty words
+ continue
+ word_tokens = self.tokenize(word)
+ pair_tokens.extend(word_tokens)
+ pair_token_boxes.extend([box] * len(word_tokens))
+
+ # Create ids + pair_ids
+ ids = self.convert_tokens_to_ids(tokens)
+ pair_ids = self.convert_tokens_to_ids(pair_tokens) if pair_tokens else None
+
+ if (
+ return_overflowing_tokens
+ and truncation_strategy == TruncationStrategy.LONGEST_FIRST
+ and pair_ids is not None
+ ):
+ raise ValueError(
+ "Not possible to return overflowing tokens for pair of sequences with the "
+ "`longest_first`. Please select another truncation strategy than `longest_first`, "
+ "for instance `only_second` or `only_first`."
+ )
+
+ # Compute the total size of the returned encodings
+ pair = bool(pair_ids is not None)
+ len_ids = len(ids)
+ len_pair_ids = len(pair_ids) if pair else 0
+ total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
+
+ # Truncation: Handle max sequence length
+ overflowing_tokens = []
+ overflowing_token_boxes = []
+ overflowing_labels = []
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
+ (
+ ids,
+ token_boxes,
+ pair_ids,
+ pair_token_boxes,
+ labels,
+ overflowing_tokens,
+ overflowing_token_boxes,
+ overflowing_labels,
+ ) = self.truncate_sequences(
+ ids,
+ token_boxes,
+ pair_ids=pair_ids,
+ pair_token_boxes=pair_token_boxes,
+ labels=labels,
+ num_tokens_to_remove=total_len - max_length,
+ truncation_strategy=truncation_strategy,
+ stride=stride,
+ )
+
+ if return_token_type_ids and not add_special_tokens:
+ raise ValueError(
+ "Asking to return token_type_ids while setting add_special_tokens to False "
+ "results in an undefined behavior. Please set add_special_tokens to True or "
+ "set return_token_type_ids to None."
+ )
+
+ # Load from model defaults
+ if return_token_type_ids is None:
+ return_token_type_ids = "token_type_ids" in self.model_input_names
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ encoded_inputs = {}
+
+ if return_overflowing_tokens:
+ encoded_inputs["overflowing_tokens"] = overflowing_tokens
+ encoded_inputs["overflowing_token_boxes"] = overflowing_token_boxes
+ encoded_inputs["overflowing_labels"] = overflowing_labels
+ encoded_inputs["num_truncated_tokens"] = total_len - max_length
+
+ # Add special tokens
+ if add_special_tokens:
+ sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
+ token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
+ token_boxes = [self.cls_token_box] + token_boxes + [self.sep_token_box]
+ if pair_token_boxes:
+ pair_token_boxes = [self.sep_token_box] + pair_token_boxes + [self.sep_token_box]
+ token_boxes = token_boxes + pair_token_boxes if pair else token_boxes
+ if labels:
+ labels = [self.pad_token_label] + labels + [self.pad_token_label]
+ else:
+ sequence = ids + pair_ids if pair else ids
+ token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
+ token_boxes = token_boxes + pair_token_boxes if pair else token_boxes
+
+ # Build output dictionary
+ encoded_inputs["input_ids"] = sequence
+ encoded_inputs["bbox"] = token_boxes
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = token_type_ids
+ if return_special_tokens_mask:
+ if add_special_tokens:
+ encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
+ else:
+ encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
+
+ if labels:
+ encoded_inputs["labels"] = labels
+
+ # Check lengths
+ self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
+
+ # Padding
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
+ encoded_inputs = self.pad(
+ encoded_inputs,
+ max_length=max_length,
+ padding=padding_strategy.value,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ )
+
+ if return_length:
+ encoded_inputs["length"] = len(encoded_inputs["input_ids"])
+
+ batch_outputs = BatchEncoding(
+ encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+ )
+
+ return batch_outputs
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.truncate_sequences
+ def truncate_sequences(
+ self,
+ ids: List[int],
+ token_boxes: List[List[int]],
+ pair_ids: Optional[List[int]] = None,
+ pair_token_boxes: Optional[List[List[int]]] = None,
+ labels: Optional[List[int]] = None,
+ num_tokens_to_remove: int = 0,
+ truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
+ stride: int = 0,
+ ) -> Tuple[List[int], List[int], List[int]]:
+ """
+ Truncates a sequence pair in-place following the strategy.
+
+ Args:
+ ids (`List[int]`):
+ Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
+ `convert_tokens_to_ids` methods.
+ token_boxes (`List[List[int]]`):
+ Bounding boxes of the first sequence.
+ pair_ids (`List[int]`, *optional*):
+ Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
+ and `convert_tokens_to_ids` methods.
+ pair_token_boxes (`List[List[int]]`, *optional*):
+ Bounding boxes of the second sequence.
+ labels (`List[int]`, *optional*):
+ Labels of the first sequence (for token classification tasks).
+ num_tokens_to_remove (`int`, *optional*, defaults to 0):
+ Number of tokens to remove using the truncation strategy.
+ truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ The strategy to follow for truncation. Can be:
+
+ - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will truncate
+ token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a
+ batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater
+ than the model maximum admissible input size).
+ stride (`int`, *optional*, defaults to 0):
+ If set to a positive number, the overflowing tokens returned will contain some tokens from the main
+ sequence returned. The value of this argument defines the number of additional tokens.
+
+ Returns:
+ `Tuple[List[int], List[int], List[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of
+ overflowing tokens. Note: the *longest_first* strategy returns an empty list of overflowing tokens if a pair
+ of sequences (or a batch of pairs) is provided.
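+
+ As an illustration (hypothetical lengths): with `truncation_strategy="only_second"`, `num_tokens_to_remove=2`
+ and `stride=0`, the last two entries of `pair_ids` and `pair_token_boxes` are moved to the overflowing outputs
+ and the returned pair sequence is shortened by two tokens.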
+ """
+ if num_tokens_to_remove <= 0:
+ return ids, token_boxes, pair_ids, pair_token_boxes, labels, [], [], []
+
+ if not isinstance(truncation_strategy, TruncationStrategy):
+ truncation_strategy = TruncationStrategy(truncation_strategy)
+
+ overflowing_tokens = []
+ overflowing_token_boxes = []
+ overflowing_labels = []
+ if truncation_strategy == TruncationStrategy.ONLY_FIRST or (
+ truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None
+ ):
+ if len(ids) > num_tokens_to_remove:
+ window_len = min(len(ids), stride + num_tokens_to_remove)
+ overflowing_tokens = ids[-window_len:]
+ overflowing_token_boxes = token_boxes[-window_len:]
+ overflowing_labels = labels[-window_len:]
+ ids = ids[:-num_tokens_to_remove]
+ token_boxes = token_boxes[:-num_tokens_to_remove]
+ labels = labels[:-num_tokens_to_remove]
+ else:
+ error_msg = (
+ f"We need to remove {num_tokens_to_remove} to truncate the input "
+ f"but the first sequence has a length {len(ids)}. "
+ )
+ if truncation_strategy == TruncationStrategy.ONLY_FIRST:
+ error_msg = (
+ error_msg
+ + "Please select another truncation strategy than "
+ f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
+ )
+ logger.error(error_msg)
+ elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
+ logger.warning(
+ "Be aware, overflowing tokens are not returned for the setting you have chosen,"
+ f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
+ "truncation strategy. So the returned list will always be empty even if some "
+ "tokens have been removed."
+ )
+ for _ in range(num_tokens_to_remove):
+ if pair_ids is None or len(ids) > len(pair_ids):
+ ids = ids[:-1]
+ token_boxes = token_boxes[:-1]
+ labels = labels[:-1]
+ else:
+ pair_ids = pair_ids[:-1]
+ pair_token_boxes = pair_token_boxes[:-1]
+ elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
+ if len(pair_ids) > num_tokens_to_remove:
+ window_len = min(len(pair_ids), stride + num_tokens_to_remove)
+ overflowing_tokens = pair_ids[-window_len:]
+ overflowing_token_boxes = pair_token_boxes[-window_len:]
+ pair_ids = pair_ids[:-num_tokens_to_remove]
+ pair_token_boxes = pair_token_boxes[:-num_tokens_to_remove]
+ else:
+ logger.error(
+ f"We need to remove {num_tokens_to_remove} to truncate the input "
+ f"but the second sequence has a length {len(pair_ids)}. "
+ f"Please select another truncation strategy than {truncation_strategy}, "
+ "for instance 'longest_first' or 'only_first'."
+ )
+
+ return (
+ ids,
+ token_boxes,
+ pair_ids,
+ pair_token_boxes,
+ labels,
+ overflowing_tokens,
+ overflowing_token_boxes,
+ overflowing_labels,
+ )
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer._pad
+ def _pad(
+ self,
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ >= 7.5 (Volta).
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
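+
+ Illustration (hypothetical sizes): right-padding a 3-token example to `max_length=5` appends two `pad_token_id`
+ entries to the input ids, two `pad_token_box` entries to `bbox`, two `pad_token_label` entries to `labels`
+ (when present), and two zeros to the attention mask.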
+ """
+ # Load from model defaults
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if return_attention_mask and "attention_mask" not in encoded_inputs:
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+ if self.padding_side == "right":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
+ )
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
+ elif self.padding_side == "left":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+ "token_type_ids"
+ ]
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"]
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"]
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+ else:
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
+
+ return encoded_inputs
diff --git a/src/transformers/models/layoutlmv3/tokenization_layoutlmv3_fast.py b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3_fast.py
new file mode 100644
index 000000000000..be5f938dbf17
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3_fast.py
@@ -0,0 +1,853 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Fast tokenization class for LayoutLMv3. It overwrites 2 methods of the slow tokenizer class, namely _batch_encode_plus
+and _encode_plus, in which the Rust tokenizer is used.
+"""
+
+import json
+from typing import Dict, List, Optional, Tuple, Union
+
+from tokenizers import pre_tokenizers, processors
+
+from ...tokenization_utils_base import (
+ BatchEncoding,
+ EncodedInput,
+ PaddingStrategy,
+ PreTokenizedInput,
+ TensorType,
+ TextInput,
+ TextInputPair,
+ TruncationStrategy,
+)
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import add_end_docstrings, logging
+from .tokenization_layoutlmv3 import (
+ LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING,
+ LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
+ LayoutLMv3Tokenizer,
+)
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/raw/main/vocab.json",
+ "microsoft/layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/raw/main/vocab.json",
+ },
+ "merges_file": {
+ "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/raw/main/merges.txt",
+ "microsoft/layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/raw/main/merges.txt",
+ },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "microsoft/layoutlmv3-base": 512,
+ "microsoft/layoutlmv3-large": 512,
+}
+
+
+class LayoutLMv3TokenizerFast(PreTrainedTokenizerFast):
+ r"""
+ Construct a "fast" LayoutLMv3 tokenizer (backed by HuggingFace's *tokenizers* library). Based on BPE.
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+ </Tip>
+
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ </Tip>
+
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows the leading word to be treated just like
+ any other word. (The RoBERTa tokenizer detects the beginning of words by the preceding space.)
+ trim_offsets (`bool`, *optional*, defaults to `True`):
+ Whether the post processing step should trim offsets to avoid including whitespaces.
+ cls_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [CLS] token.
+ sep_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [SEP] token.
+ pad_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [PAD] token.
+ pad_token_label (`int`, *optional*, defaults to -100):
+ The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's
+ CrossEntropyLoss.
+ only_label_first_subword (`bool`, *optional*, defaults to `True`):
+ Whether or not to only label the first subword, in case word labels are provided.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = LayoutLMv3Tokenizer
+
+ def __init__(
+ self,
+ vocab_file=None,
+ merges_file=None,
+ tokenizer_file=None,
+ errors="replace",
+ bos_token="<s>",
+ eos_token="</s>",
+ sep_token="</s>",
+ cls_token="<s>",
+ unk_token="<unk>",
+ pad_token="<pad>",
+ mask_token="<mask>",
+ add_prefix_space=True,
+ trim_offsets=True,
+ cls_token_box=[0, 0, 0, 0],
+ sep_token_box=[0, 0, 0, 0],
+ pad_token_box=[0, 0, 0, 0],
+ pad_token_label=-100,
+ only_label_first_subword=True,
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file,
+ merges_file,
+ tokenizer_file=tokenizer_file,
+ errors=errors,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ add_prefix_space=add_prefix_space,
+ trim_offsets=trim_offsets,
+ cls_token_box=cls_token_box,
+ sep_token_box=sep_token_box,
+ pad_token_box=pad_token_box,
+ pad_token_label=pad_token_label,
+ only_label_first_subword=only_label_first_subword,
+ **kwargs,
+ )
+
+ pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
+ if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+ pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
+ pre_tok_state["add_prefix_space"] = add_prefix_space
+ self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
+
+ self.add_prefix_space = add_prefix_space
+
+ tokenizer_component = "post_processor"
+ tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
+ if tokenizer_component_instance:
+ state = json.loads(tokenizer_component_instance.__getstate__())
+
+ # The lists 'sep' and 'cls' must be cast to tuples for the object `post_processor_class`
+ if "sep" in state:
+ state["sep"] = tuple(state["sep"])
+ if "cls" in state:
+ state["cls"] = tuple(state["cls"])
+
+ changes_to_apply = False
+
+ if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+ state["add_prefix_space"] = add_prefix_space
+ changes_to_apply = True
+
+ if state.get("trim_offsets", trim_offsets) != trim_offsets:
+ state["trim_offsets"] = trim_offsets
+ changes_to_apply = True
+
+ if changes_to_apply:
+ component_class = getattr(processors, state.pop("type"))
+ new_value = component_class(**state)
+ setattr(self.backend_tokenizer, tokenizer_component, new_value)
+
+ # additional properties
+ self.cls_token_box = cls_token_box
+ self.sep_token_box = sep_token_box
+ self.pad_token_box = pad_token_box
+ self.pad_token_label = pad_token_label
+ self.only_label_first_subword = only_label_first_subword
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.__call__
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+ text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
+ boxes: Union[List[List[int]], List[List[List[int]]]] = None,
+ word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+ sequences with word-level normalized bounding boxes and optional labels.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
+ (words of a single example or questions of a batch of examples) or a list of list of strings (batch of
+ words).
+ text_pair (`List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
+ (pretokenized string).
+ boxes (`List[List[int]]`, `List[List[List[int]]]`):
+ Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
+ word_labels (`List[int]`, `List[List[int]]`, *optional*):
+ Word-level integer labels (for token classification tasks such as FUNSD, CORD).
+ """
+ # Input type checking for clearer error
+ def _is_valid_text_input(t):
+ if isinstance(t, str):
+ # Strings are fine
+ return True
+ elif isinstance(t, (list, tuple)):
+ # List are fine as long as they are...
+ if len(t) == 0:
+ # ... empty
+ return True
+ elif isinstance(t[0], str):
+ # ... list of strings
+ return True
+ elif isinstance(t[0], (list, tuple)):
+ # ... list with an empty list or with a list of strings
+ return len(t[0]) == 0 or isinstance(t[0][0], str)
+ else:
+ return False
+ else:
+ return False
+
+ if text_pair is not None:
+ # in case text + text_pair are provided, text = questions, text_pair = words
+ if not _is_valid_text_input(text):
+ raise ValueError("text input must of type `str` (single example) or `List[str]` (batch of examples). ")
+ if not isinstance(text_pair, (list, tuple)):
+ raise ValueError(
+ "Words must be of type `List[str]` (single pretokenized example), "
+ "or `List[List[str]]` (batch of pretokenized examples)."
+ )
+ else:
+ # in case only text is provided => must be words
+ if not isinstance(text, (list, tuple)):
+ raise ValueError(
+ "Words must be of type `List[str]` (single pretokenized example), "
+ "or `List[List[str]]` (batch of pretokenized examples)."
+ )
+
+ if text_pair is not None:
+ is_batched = isinstance(text, (list, tuple))
+ else:
+ is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
+
+ words = text if text_pair is None else text_pair
+ if boxes is None:
+ raise ValueError("You must provide corresponding bounding boxes")
+ if is_batched:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide words and boxes for an equal amount of examples")
+ for words_example, boxes_example in zip(words, boxes):
+ if len(words_example) != len(boxes_example):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+ else:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+
+ if is_batched:
+ if text_pair is not None and len(text) != len(text_pair):
+ raise ValueError(
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
+ )
+ batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
+ is_pair = bool(text_pair is not None)
+ return self.batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+ else:
+ return self.encode_plus(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
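A possible usage sketch of the `__call__` API defined above, once this new tokenizer class is importable from `transformers`. It assumes the `microsoft/layoutlmv3-base` checkpoint is reachable; the words and (0-1000 normalized) boxes are invented.

```python
from transformers import LayoutLMv3TokenizerFast

tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")

words = ["hello", "world"]
boxes = [[48, 84, 73, 128], [79, 84, 125, 128]]  # one box per word, normalized to 0-1000

encoding = tokenizer(words, boxes=boxes, padding="max_length", max_length=8, return_tensors="pt")
print(sorted(encoding.keys()))  # expected: ['attention_mask', 'bbox', 'input_ids']
```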
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.batch_encode_plus
+ def batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ List[TextInput],
+ List[TextInputPair],
+ List[PreTokenizedInput],
+ ],
+ is_pair: bool = None,
+ boxes: Optional[List[List[List[int]]]] = None,
+ word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.tokenize
+ def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
+ batched_input = [(text, pair)] if pair else [text]
+ encodings = self._tokenizer.encode_batch(
+ batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs
+ )
+
+ return encodings[0].tokens
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.encode_plus
+ def encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[List[List[int]]] = None,
+ word_labels: Optional[List[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ Tokenize and prepare for the model a sequence or a pair of sequences.
+
+ .. warning:: This method is deprecated, `__call__` should be used instead.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
+ text_pair (`List[str]` or `List[int]`, *optional*):
+ Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
+ list of list of strings (words of a batch of examples).
+ """
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._encode_plus(
+ text=text,
+ boxes=boxes,
+ text_pair=text_pair,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast._batch_encode_plus with LayoutLMv2->LayoutLMv3
+ def _batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ List[TextInput],
+ List[TextInputPair],
+ List[PreTokenizedInput],
+ ],
+ is_pair: bool = None,
+ boxes: Optional[List[List[List[int]]]] = None,
+ word_labels: Optional[List[List[int]]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[str] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ ) -> BatchEncoding:
+
+ if not isinstance(batch_text_or_text_pairs, list):
+ raise TypeError(f"batch_text_or_text_pairs has to be a list (got {type(batch_text_or_text_pairs)})")
+
+ # Set the truncation and padding strategy and restore the initial configuration
+ self.set_truncation_and_padding(
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ )
+
+ if is_pair:
+ batch_text_or_text_pairs = [(text.split(), text_pair) for text, text_pair in batch_text_or_text_pairs]
+
+ encodings = self._tokenizer.encode_batch(
+ batch_text_or_text_pairs,
+ add_special_tokens=add_special_tokens,
+ is_pretokenized=True, # we set this to True as LayoutLMv3 always expects pretokenized inputs
+ )
+
+ # Convert encoding to dict
+ # `Tokens` has type: Tuple[
+ # List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],
+ # List[EncodingFast]
+ # ]
+ # with nested dimensions corresponding to batch, overflows, sequence length
+ tokens_and_encodings = [
+ self._convert_encoding(
+ encoding=encoding,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=True
+ if word_labels is not None
+ else return_offsets_mapping, # we use offsets to create the labels
+ return_length=return_length,
+ verbose=verbose,
+ )
+ for encoding in encodings
+ ]
+
+ # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
+ # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
+ # (we say ~ because the number of overflow varies with the example in the batch)
+ #
+ # To match each overflowing sample with the original sample in the batch
+ # we add an overflow_to_sample_mapping array (see below)
+ sanitized_tokens = {}
+ for key in tokens_and_encodings[0][0].keys():
+ stack = [e for item, _ in tokens_and_encodings for e in item[key]]
+ sanitized_tokens[key] = stack
+ sanitized_encodings = [e for _, item in tokens_and_encodings for e in item]
+
+ # If returning overflowing tokens, we need to return a mapping
+ # from the batch idx to the original sample
+ if return_overflowing_tokens:
+ overflow_to_sample_mapping = []
+ for i, (toks, _) in enumerate(tokens_and_encodings):
+ overflow_to_sample_mapping += [i] * len(toks["input_ids"])
+ sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
+
+ for input_ids in sanitized_tokens["input_ids"]:
+ self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
+
+ # create the token boxes
+ token_boxes = []
+ for batch_index in range(len(sanitized_tokens["input_ids"])):
+ if return_overflowing_tokens:
+ original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
+ else:
+ original_index = batch_index
+ token_boxes_example = []
+ for id, sequence_id, word_id in zip(
+ sanitized_tokens["input_ids"][batch_index],
+ sanitized_encodings[batch_index].sequence_ids,
+ sanitized_encodings[batch_index].word_ids,
+ ):
+ if word_id is not None:
+ if is_pair and sequence_id == 0:
+ token_boxes_example.append(self.pad_token_box)
+ else:
+ token_boxes_example.append(boxes[original_index][word_id])
+ else:
+ if id == self.cls_token_id:
+ token_boxes_example.append(self.cls_token_box)
+ elif id == self.sep_token_id:
+ token_boxes_example.append(self.sep_token_box)
+ elif id == self.pad_token_id:
+ token_boxes_example.append(self.pad_token_box)
+ else:
+ raise ValueError("Id not recognized")
+ token_boxes.append(token_boxes_example)
+
+ sanitized_tokens["bbox"] = token_boxes
+
+ # optionally, create the labels
+ if word_labels is not None:
+ labels = []
+ for batch_index in range(len(sanitized_tokens["input_ids"])):
+ if return_overflowing_tokens:
+ original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
+ else:
+ original_index = batch_index
+ labels_example = []
+ for id, offset, word_id in zip(
+ sanitized_tokens["input_ids"][batch_index],
+ sanitized_tokens["offset_mapping"][batch_index],
+ sanitized_encodings[batch_index].word_ids,
+ ):
+ if word_id is not None:
+ if self.only_label_first_subword:
+ if offset[0] == 0:
+ # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+ labels_example.append(word_labels[original_index][word_id])
+ else:
+ labels_example.append(self.pad_token_label)
+ else:
+ labels_example.append(word_labels[original_index][word_id])
+ else:
+ labels_example.append(self.pad_token_label)
+ labels.append(labels_example)
+
+ sanitized_tokens["labels"] = labels
+ # finally, remove offsets if the user didn't want them
+ if not return_offsets_mapping:
+ del sanitized_tokens["offset_mapping"]
+
+ return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
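To make the `only_label_first_subword` rule above concrete, here is a standalone sketch using invented word ids and character offsets in place of a real `tokenizers` encoding.

```python
# Illustration only: propagate word-level labels to sub-tokens.
pad_token_label = -100
word_labels = [7, 3]                      # one label per *word* (invented)
word_ids = [None, 0, 0, 1, None]          # None marks special tokens
offsets = [(0, 0), (0, 3), (3, 5), (0, 5), (0, 0)]

labels = []
for offset, word_id in zip(offsets, word_ids):
    if word_id is None:
        labels.append(pad_token_label)
    elif offset[0] == 0:                  # first sub-token of a word keeps the real label
        labels.append(word_labels[word_id])
    else:                                 # remaining sub-tokens are ignored by the loss
        labels.append(pad_token_label)

assert labels == [-100, 7, -100, 3, -100]
```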
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast._encode_plus
+ def _encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[List[List[int]]] = None,
+ word_labels: Optional[List[int]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[bool] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+
+ # make it a batched input
+ # 2 options:
+ # 1) only text, in which case text must be a list of str
+ # 2) text + text_pair, in which case text = str and text_pair a list of str
+ batched_input = [(text, text_pair)] if text_pair else [text]
+ batched_boxes = [boxes]
+ batched_word_labels = [word_labels] if word_labels is not None else None
+ batched_output = self._batch_encode_plus(
+ batched_input,
+ is_pair=bool(text_pair is not None),
+ boxes=batched_boxes,
+ word_labels=batched_word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ # If return_tensors is None, we can remove the leading batch axis
+ # Overflowing tokens are returned as a batch of output so we keep them in this case
+ if return_tensors is None and not return_overflowing_tokens:
+ batched_output = BatchEncoding(
+ {
+ key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
+ for key, value in batched_output.items()
+ },
+ batched_output.encodings,
+ )
+
+ self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
+
+ return batched_output
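A small illustration (with an invented dict standing in for the `BatchEncoding`) of the un-batching step above when `return_tensors is None`.

```python
batched_output = {
    "input_ids": [[0, 31414, 232, 2]],
    "bbox": [[[0, 0, 0, 0], [48, 84, 73, 128], [79, 84, 125, 128], [0, 0, 0, 0]]],
}
unbatched = {
    key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
    for key, value in batched_output.items()
}
assert unbatched["input_ids"] == [0, 31414, 232, 2]
assert len(unbatched["bbox"]) == 4  # leading batch axis removed
```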
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast._pad
+ def _pad(
+ self,
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ >= 7.5 (Volta).
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if return_attention_mask and "attention_mask" not in encoded_inputs:
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+ if self.padding_side == "right":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
+ )
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
+ elif self.padding_side == "left":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+ "token_type_ids"
+ ]
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"]
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"]
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+ else:
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
+
+ return encoded_inputs
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.save_vocabulary
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+ if token_ids_1 is None:
+ return output
+
+ return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
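Sketch of the sequence layouts produced by `build_inputs_with_special_tokens` above, assuming RoBERTa-style ids `bos=0` (`<s>`) and `eos=2` (`</s>`); the other ids are invented.

```python
bos_token_id, eos_token_id = 0, 2
token_ids_0, token_ids_1 = [31414, 232], [100]

single = [bos_token_id] + token_ids_0 + [eos_token_id]
pair = single + [eos_token_id] + token_ids_1 + [eos_token_id]

print(single)  # [0, 31414, 232, 2]            i.e. <s> A </s>
print(pair)    # [0, 31414, 232, 2, 2, 100, 2] i.e. <s> A </s></s> B </s>
```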
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Args:
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not:
+ make use of token type ids, therefore a list of zeros is returned.
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ Returns:
+ `List[int]`: List of zeros.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
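And the corresponding output of `create_token_type_ids_from_sequences`: LayoutLMv3, like RoBERTa, does not use token type ids, so only the length of the zero mask changes. Ids are invented.

```python
sep, cls = [2], [0]
token_ids_0, token_ids_1 = [31414, 232], [100]

single = len(cls + token_ids_0 + sep) * [0]
pair = len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

assert single == [0, 0, 0, 0]
assert pair == [0] * 7
```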
diff --git a/src/transformers/models/layoutxlm/__init__.py b/src/transformers/models/layoutxlm/__init__.py
index c9459aff2033..9c09d75d68ba 100644
--- a/src/transformers/models/layoutxlm/__init__.py
+++ b/src/transformers/models/layoutxlm/__init__.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_tokenizers_available,
@@ -27,27 +28,43 @@
)
-_import_structure = {}
+_import_structure = {"processing_layoutxlm": ["LayoutXLMProcessor"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_layoutxlm"] = ["LayoutXLMTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_layoutxlm_fast"] = ["LayoutXLMTokenizerFast"]
-if is_vision_available():
- _import_structure["processing_layoutxlm"] = ["LayoutXLMProcessor"]
-
if TYPE_CHECKING:
- if is_sentencepiece_available():
+ from .processing_layoutxlm import LayoutXLMProcessor
+
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_layoutxlm import LayoutXLMTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_layoutxlm_fast import LayoutXLMTokenizerFast
- if is_vision_available():
- from .processing_layoutlmv2 import LayoutXLMProcessor
-
else:
import sys
diff --git a/src/transformers/models/layoutxlm/processing_layoutxlm.py b/src/transformers/models/layoutxlm/processing_layoutxlm.py
index 99245ccc1776..03423d17c27b 100644
--- a/src/transformers/models/layoutxlm/processing_layoutxlm.py
+++ b/src/transformers/models/layoutxlm/processing_layoutxlm.py
@@ -86,8 +86,7 @@ def __call__(
if self.feature_extractor.apply_ocr and (word_labels is not None):
raise ValueError(
- "You cannot provide word labels "
- "if you initialized the feature extractor with apply_ocr set to True."
+ "You cannot provide word labels if you initialized the feature extractor with apply_ocr set to True."
)
# first, apply the feature extractor
@@ -122,6 +121,37 @@ def __call__(
)
# add pixel values
- encoded_inputs["image"] = features.pop("pixel_values")
+ images = features.pop("pixel_values")
+ if return_overflowing_tokens is True:
+ images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"])
+ encoded_inputs["image"] = images
return encoded_inputs
+
+ def get_overflowing_images(self, images, overflow_to_sample_mapping):
+ # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image
+ images_with_overflow = []
+ for sample_idx in overflow_to_sample_mapping:
+ images_with_overflow.append(images[sample_idx])
+
+ if len(images_with_overflow) != len(overflow_to_sample_mapping):
+ raise ValueError(
+ "Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got"
+ f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}"
+ )
+
+ return images_with_overflow
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
+ to the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
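A minimal sketch of the overflow handling added above: when one document overflows into several input windows, its page image is repeated once per window. The mapping and image placeholders below are invented.

```python
overflow_to_sample_mapping = [0, 0, 1]  # sample 0 produced two windows, sample 1 one
images = ["pixels_of_doc_0", "pixels_of_doc_1"]  # placeholders for pixel_values arrays

images_with_overflow = [images[sample_idx] for sample_idx in overflow_to_sample_mapping]
assert images_with_overflow == ["pixels_of_doc_0", "pixels_of_doc_0", "pixels_of_doc_1"]
```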
diff --git a/src/transformers/models/layoutxlm/tokenization_layoutxlm.py b/src/transformers/models/layoutxlm/tokenization_layoutxlm.py
index 8fded392844d..52d9b3ba802d 100644
--- a/src/transformers/models/layoutxlm/tokenization_layoutxlm.py
+++ b/src/transformers/models/layoutxlm/tokenization_layoutxlm.py
@@ -20,11 +20,9 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import sentencepiece as spm
-from transformers.models.layoutlmv2.tokenization_layoutlmv2 import LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...tokenization_utils_base import (
- ENCODE_KWARGS_DOCSTRING,
BatchEncoding,
EncodedInput,
PreTokenizedInput,
@@ -44,6 +42,110 @@
logger = logging.get_logger(__name__)
+LAYOUTXLM_ENCODE_KWARGS_DOCSTRING = r"""
+ add_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to encode the sequences with the special tokens relative to their model.
+ padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+ to the maximum acceptable input length for the model if that argument is not provided. This will
+ truncate token by token, removing a token from the longest sequence in the pair if a pair of
+ sequences (or a batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters.
+
+ If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+ is required by one of the truncation/padding parameters. If the model has no specific maximum input
+ length (like XLNet) truncation/padding to a maximum length will be deactivated.
+ stride (`int`, *optional*, defaults to 0):
+ If set to a number along with `max_length`, the overflowing tokens returned when
+ `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
+ returned to provide some overlap between truncated and overflowing sequences. The value of this
+ argument defines the number of overlapping tokens.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+ the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ return_token_type_ids (`bool`, *optional*):
+ Whether to return token type IDs. If left to the default, will return the token type IDs according to
+ the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are attention masks?](../glossary#attention-mask)
+ return_overflowing_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
+ of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead
+ of returning overflowing tokens.
+ return_special_tokens_mask (`bool`, *optional*, defaults to `False`):
+ Whether or not to return special tokens mask information.
+ return_offsets_mapping (`bool`, *optional*, defaults to `False`):
+ Whether or not to return `(char_start, char_end)` for each token.
+
+ This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`]; if using
+ Python's tokenizer, this method will raise `NotImplementedError`.
+ return_length (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the lengths of the encoded inputs.
+ verbose (`bool`, *optional*, defaults to `True`):
+ Whether or not to print more information and warnings.
+ **kwargs: passed to the `self.tokenize()` method
+
+ Return:
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ - **bbox** -- List of bounding boxes to be fed to a model.
+
+ - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or
+ if *"token_type_ids"* is in `self.model_input_names`).
+
+ [What are token type IDs?](../glossary#token-type-ids)
+
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ - **labels** -- List of labels to be fed to a model (when `word_labels` is specified).
+ - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and
+ `return_overflowing_tokens=True`).
+ - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and
+ `return_overflowing_tokens=True`).
+ - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying
+ regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`).
+ - **length** -- The length of the inputs (when `return_length=True`).
+"""
+
+
class LayoutXLMTokenizer(PreTrainedTokenizer):
"""
Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
@@ -339,7 +441,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
return (out_vocab_file,)
- @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ @add_end_docstrings(LAYOUTXLM_ENCODE_KWARGS_DOCSTRING)
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
@@ -438,7 +540,8 @@ def _is_valid_text_input(t):
if is_batched:
if text_pair is not None and len(text) != len(text_pair):
raise ValueError(
- f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}."
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
)
batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
is_pair = bool(text_pair is not None)
@@ -542,7 +645,7 @@ def _batch_encode_plus(
return BatchEncoding(batch_outputs)
- @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ @add_end_docstrings(LAYOUTXLM_ENCODE_KWARGS_DOCSTRING)
def _batch_prepare_for_model(
self,
batch_text_or_text_pairs,
@@ -665,7 +768,7 @@ def _encode_plus(
verbose=verbose,
)
- @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ @add_end_docstrings(LAYOUTXLM_ENCODE_KWARGS_DOCSTRING)
def prepare_for_model(
self,
text: Union[TextInput, PreTokenizedInput],
@@ -960,7 +1063,7 @@ def truncate_sequences(
f"We need to remove {num_tokens_to_remove} to truncate the input "
f"but the first sequence has a length {len(ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
- f"for instance 'longest_first' or 'only_second'."
+ "for instance 'longest_first' or 'only_second'."
)
elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
if len(pair_ids) > num_tokens_to_remove:
@@ -974,7 +1077,7 @@ def truncate_sequences(
f"We need to remove {num_tokens_to_remove} to truncate the input "
f"but the second sequence has a length {len(pair_ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
- f"for instance 'longest_first' or 'only_first'."
+ "for instance 'longest_first' or 'only_first'."
)
return (
diff --git a/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py b/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py
index 35b438387747..71a76614376a 100644
--- a/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py
+++ b/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py
@@ -19,11 +19,8 @@
from shutil import copyfile
from typing import Dict, List, Optional, Tuple, Union
-from transformers.models.layoutlmv2.tokenization_layoutlmv2 import LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING
-
from ...tokenization_utils import AddedToken
from ...tokenization_utils_base import (
- ENCODE_KWARGS_DOCSTRING,
BatchEncoding,
EncodedInput,
PreTokenizedInput,
@@ -48,6 +45,109 @@
logger = logging.get_logger(__name__)
+LAYOUTXLM_ENCODE_KWARGS_DOCSTRING = r"""
+ add_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to encode the sequences with the special tokens relative to their model.
+ padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+ to the maximum acceptable input length for the model if that argument is not provided. This will
+ truncate token by token, removing a token from the longest sequence in the pair if a pair of
+ sequences (or a batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters.
+
+ If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+ is required by one of the truncation/padding parameters. If the model has no specific maximum input
+ length (like XLNet) truncation/padding to a maximum length will be deactivated.
+ stride (`int`, *optional*, defaults to 0):
+ If set to a number along with `max_length`, the overflowing tokens returned when
+ `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
+ returned to provide some overlap between truncated and overflowing sequences. The value of this
+ argument defines the number of overlapping tokens.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+ the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ return_token_type_ids (`bool`, *optional*):
+ Whether to return token type IDs. If left to the default, will return the token type IDs according to
+ the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are attention masks?](../glossary#attention-mask)
+ return_overflowing_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
+ of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead
+ of returning overflowing tokens.
+ return_special_tokens_mask (`bool`, *optional*, defaults to `False`):
+ Whether or not to return special tokens mask information.
+ return_offsets_mapping (`bool`, *optional*, defaults to `False`):
+ Whether or not to return `(char_start, char_end)` for each token.
+
+ This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`]; if using
+ Python's tokenizer, this method will raise `NotImplementedError`.
+ return_length (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the lengths of the encoded inputs.
+ verbose (`bool`, *optional*, defaults to `True`):
+ Whether or not to print more information and warnings.
+ **kwargs: passed to the `self.tokenize()` method
+
+ Return:
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ - **bbox** -- List of bounding boxes to be fed to a model.
+
+ - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or
+ if *"token_type_ids"* is in `self.model_input_names`).
+
+ [What are token type IDs?](../glossary#token-type-ids)
+
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ - **labels** -- List of labels to be fed to a model (when `word_labels` is specified).
+ - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and
+ `return_overflowing_tokens=True`).
+ - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and
+ `return_overflowing_tokens=True`).
+ - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying
+ regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`).
+ - **length** -- The length of the inputs (when `return_length=True`).
+"""
+
class LayoutXLMTokenizerFast(PreTrainedTokenizerFast):
"""
@@ -166,7 +266,7 @@ def __init__(
self.pad_token_label = pad_token_label
self.only_label_first_subword = only_label_first_subword
- @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ @add_end_docstrings(LAYOUTXLM_ENCODE_KWARGS_DOCSTRING)
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
@@ -265,7 +365,8 @@ def _is_valid_text_input(t):
if is_batched:
if text_pair is not None and len(text) != len(text_pair):
raise ValueError(
- f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}."
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
)
batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
is_pair = bool(text_pair is not None)
diff --git a/src/transformers/models/led/__init__.py b/src/transformers/models/led/__init__.py
index d60800f981a5..da871828ad88 100644
--- a/src/transformers/models/led/__init__.py
+++ b/src/transformers/models/led/__init__.py
@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -25,10 +31,20 @@
"tokenization_led": ["LEDTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_led_fast"] = ["LEDTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_led"] = [
"LED_PRETRAINED_MODEL_ARCHIVE_LIST",
"LEDForConditionalGeneration",
@@ -39,7 +55,12 @@
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_led"] = ["TFLEDForConditionalGeneration", "TFLEDModel", "TFLEDPreTrainedModel"]
@@ -47,10 +68,20 @@
from .configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
from .tokenization_led import LEDTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_led_fast import LEDTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_led import (
LED_PRETRAINED_MODEL_ARCHIVE_LIST,
LEDForConditionalGeneration,
@@ -60,7 +91,12 @@
LEDPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel
else:
diff --git a/src/transformers/models/led/configuration_led.py b/src/transformers/models/led/configuration_led.py
index 5f534ab28703..37720c730af1 100644
--- a/src/transformers/models/led/configuration_led.py
+++ b/src/transformers/models/led/configuration_led.py
@@ -86,18 +86,17 @@ class LEDConfig(PretrainedConfig):
Example:
```python
+ >>> from transformers import LEDModel, LEDConfig
- ```
+ >>> # Initializing a LED allenai/led-base-16384 style configuration
+ >>> configuration = LEDConfig()
- >>> from transformers import LEDModel, LEDConfig
+ >>> # Initializing a model from the allenai/led-base-16384 style configuration
+ >>> model = LEDModel(configuration)
- >>> # Initializing a LED allenai/led-base-16384 style configuration >>> configuration = LEDConfig()
-
- >>> # Initializing a model from the allenai/led-base-16384 style configuration >>> model =
- LEDModel(configuration)
-
- >>> # Accessing the model configuration >>> configuration = model.config
- """
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
model_type = "led"
attribute_map = {
"num_attention_heads": "encoder_attention_heads",
diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py
index 3e852cf2a67d..0837ac2bc423 100755
--- a/src/transformers/models/led/modeling_led.py
+++ b/src/transformers/models/led/modeling_led.py
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
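
Not part of the patch, but a quick illustration of why the hard-coded `float("-inf")` is replaced by `torch.finfo(dtype).min`: the minimum finite value depends on the dtype, and using it keeps the additive mask finite (avoiding `inf - inf = nan` in half precision) while still dwarfing any real attention score.

```python
import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    print(dtype, torch.finfo(dtype).min)
# float32  -> ~-3.40e38
# float16  -> -65504.0
# bfloat16 -> ~-3.39e38
# A fixed -10000.0 is not always "negative enough" relative to fp16 scores,
# and float("-inf") can produce NaNs once masked positions are subtracted.
```
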
@@ -207,7 +207,7 @@ def forward(
# cast to fp32/fp16 then replace 1's with -inf
float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
- remove_from_windowed_attention_mask, -10000.0
+ remove_from_windowed_attention_mask, torch.finfo(query_vectors.dtype).min
)
# diagonal mask with zeros everywhere and -inf inplace of padding
diagonal_mask = self._sliding_chunks_query_key_matmul(
@@ -222,7 +222,10 @@ def forward(
seq_len,
self.num_heads,
self.one_sided_attn_window_size * 2 + 1,
- ], f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
+ ], (
+ f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
+ f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
+ )
# compute local attention probs from global attention keys and contact over window dim
if is_global_attn:
@@ -576,7 +579,7 @@ def _concat_with_global_key_attn_probs(
attn_probs_from_global_key[
is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1]
- ] = -10000.0
+ ] = torch.finfo(attn_probs_from_global_key.dtype).min
return attn_probs_from_global_key
@@ -662,17 +665,21 @@ def _compute_global_attn_output_from_hidden(
batch_size * self.num_heads,
max_num_global_attn_indices,
seq_len,
- ], f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {global_attn_scores.size()}."
+ ], (
+ "global_attn_scores have the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
+ f" {global_attn_scores.size()}."
+ )
global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
global_attn_scores[
is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
- ] = -10000.0
+ ] = torch.finfo(global_attn_scores.dtype).min
global_attn_scores = global_attn_scores.masked_fill(
is_index_masked[:, None, None, :],
- -10000.0,
+ torch.finfo(global_attn_scores.dtype).min,
)
global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
@@ -705,7 +712,11 @@ def _compute_global_attn_output_from_hidden(
batch_size * self.num_heads,
max_num_global_attn_indices,
self.head_dim,
- ], f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attn_output.size()}."
+ ], (
+ "global_attn_output tensor has the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
+ f" {global_attn_output.size()}."
+ )
global_attn_probs = global_attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
global_attn_output = global_attn_output.view(
@@ -766,7 +777,8 @@ def __init__(
self.head_dim = embed_dim // num_heads
if self.head_dim * num_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
@@ -837,7 +849,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -852,7 +865,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -873,7 +887,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = (
@@ -1007,7 +1022,7 @@ def forward(
"""
residual = hidden_states
- # Self Attention
+ # Self-Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
@@ -1437,13 +1452,11 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
LED_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
+ This model inherits from [`PreTrainedModel`]. See the superclass documentation for the generic methods the library
+ implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for general usage and behavior.
Parameters:
config ([`LEDConfig`]):
@@ -1595,7 +1608,7 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
class LEDEncoder(LEDPreTrainedModel):
"""
- Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
[`LEDEncoderLayer`].
Args:
@@ -1643,7 +1656,7 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = Non
self.post_init()
def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor):
- # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
+ # longformer self-attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
# (global_attention_mask + 1) => 1 for local attention, 2 for global attention
# => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
if attention_mask is not None:
@@ -1815,7 +1828,8 @@ def forward(
if head_mask is not None:
if head_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
@@ -1995,8 +2009,8 @@ def forward(
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
- all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor`
- of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
+ shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
control over how to convert `input_ids` indices into associated vectors than the model's internal
embedding lookup matrix.
@@ -2071,7 +2085,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -2283,9 +2298,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
base_model_prefix = "led"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
- r"encoder\.version",
- r"decoder\.version",
- r"lm_head\.weight",
+ r"encoder.version",
+ r"decoder.version",
+ r"lm_head.weight",
]
def __init__(self, config: LEDConfig):
@@ -2426,6 +2441,7 @@ def prepare_inputs_for_generation(
decoder_input_ids,
past=None,
attention_mask=None,
+ global_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@@ -2443,6 +2459,7 @@ def prepare_inputs_for_generation(
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
+ "global_attention_mask": global_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py
index a882e32ec4e7..7ff69c2a634a 100644
--- a/src/transformers/models/led/modeling_tf_led.py
+++ b/src/transformers/models/led/modeling_tf_led.py
@@ -74,11 +74,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
return shifted_input_ids
+# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
- bsz, tgt_len = input_ids_shape
+ bsz = input_ids_shape[0]
+ tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
@@ -90,7 +92,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
-def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
+# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -246,7 +249,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_scores),
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
- message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}",
+ message=(
+ f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
+ f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
+ ),
)
# compute global attn indices required through out forward fn
@@ -299,7 +305,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
@@ -392,7 +401,10 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
tf.debugging.assert_equal(
shape_list(query),
shape_list(key),
- message=f"Shape of query and key should be equal, but got query: {shape_list(query)} and key: {shape_list(key)}",
+ message=(
+ f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
+ f" {shape_list(key)}"
+ ),
)
chunks_count = seq_len // window_overlap - 1
@@ -677,7 +689,10 @@ def _chunk(hidden_states, window_overlap):
tf.debugging.assert_equal(
shape_list(chunked_hidden_states),
[batch_size, num_output_chunks, frame_size],
- message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.",
+ message=(
+ "Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
+ f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
+ ),
)
chunked_hidden_states = tf.reshape(
@@ -855,7 +870,11 @@ def _compute_global_attn_output_from_hidden(
tf.debugging.assert_equal(
shape_list(global_attn_scores),
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
- message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {shape_list(global_attn_scores)}.",
+ message=(
+ "global_attn_scores have the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
+ f" {shape_list(global_attn_scores)}."
+ ),
)
global_attn_scores = tf.reshape(
@@ -894,7 +913,10 @@ def _compute_global_attn_output_from_hidden(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
@@ -913,7 +935,11 @@ def _compute_global_attn_output_from_hidden(
tf.debugging.assert_equal(
shape_list(global_attn_output),
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
- message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {shape_list(global_attn_output)}.",
+ message=(
+ "global_attn_output tensor has the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
+ f" {shape_list(global_attn_output)}."
+ ),
)
global_attn_output = tf.reshape(
@@ -1069,7 +1095,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
- message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
)
if attention_mask is not None:
@@ -1077,7 +1106,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
- message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast(
@@ -1092,7 +1124,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
@@ -1108,7 +1143,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
- message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
)
attn_output = tf.transpose(
@@ -1238,7 +1276,7 @@ def call(
"""
residual = hidden_states
- # Self Attention
+ # Self-Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
@@ -1612,7 +1650,7 @@ class TFLEDSeq2SeqLMOutput(ModelOutput):
class TFLEDEncoder(tf.keras.layers.Layer):
config_class = LEDConfig
"""
- Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
[`TFLEDEncoderLayer`].
Args:
@@ -1753,7 +1791,10 @@ def call(
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
- message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
+ message=(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(head_mask)[0]}."
+ ),
)
# encoder layers
@@ -1950,7 +1991,7 @@ def call(
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding. If `past_key_values` are used, the user can optionally input only the last
`decoder_input_ids` (those that don't have their past key value states given to this model) of shape
- `(batch_size, 1)` instead of all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`.
+ `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
@@ -2013,7 +2054,10 @@ def call(
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
- message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
+ message=(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(head_mask)[0]}."
+ ),
)
for idx, decoder_layer in enumerate(self.layers):
@@ -2286,10 +2330,12 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.led = TFLEDMainLayer(config, name="led")
self.use_cache = config.use_cache
- # final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
+ # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self.final_logits_bias = self.add_weight(
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
)
+ # TODO (Joao): investigate why LED has numerical issues in XLA generate
+ self.supports_xla_generation = False
def get_decoder(self):
return self.led.decoder
@@ -2459,11 +2505,19 @@ def _reorder_cache(past, beam_idx):
def hf_compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens"""
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
- from_logits=True,
- reduction=tf.keras.losses.Reduction.NONE,
- )
- melted_labels = tf.reshape(labels, (-1,))
- active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
- reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
- labels = tf.boolean_mask(melted_labels, active_loss)
- return loss_fn(labels, reduced_logits)
+ from_logits=True, reduction=tf.keras.losses.Reduction.NONE
+ )
+ if self.config.tf_legacy_loss:
+ melted_labels = tf.reshape(labels, (-1,))
+ active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
+ reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
+ labels = tf.boolean_mask(melted_labels, active_loss)
+ return loss_fn(labels, reduced_logits)
+
+ # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+ unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
+ # make sure only non-padding labels affect the loss
+ loss_mask = tf.cast(labels != self.config.pad_token_id, dtype=unmasked_loss.dtype)
+ masked_loss = unmasked_loss * loss_mask
+ reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
+ return tf.reshape(reduced_masked_loss, (1,))
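
A toy walk-through of the new non-legacy loss path above, assuming a pad token id of 1 (as used by LED): padding labels are clipped to zero before the loss call and then excluded from the average via the mask.

```python
import tensorflow as tf

pad_token_id = 1
labels = tf.constant([[5, 7, pad_token_id]])   # last position is padding
logits = tf.random.normal((1, 3, 10))          # (batch, seq_len, vocab)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
unmasked_loss = loss_fn(tf.nn.relu(labels), logits)            # shape (1, 3)
loss_mask = tf.cast(labels != pad_token_id, unmasked_loss.dtype)
masked_loss = unmasked_loss * loss_mask
reduced = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
print(tf.reshape(reduced, (1,)))  # loss averaged over the two non-pad tokens
```
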
diff --git a/src/transformers/models/levit/__init__.py b/src/transformers/models/levit/__init__.py
new file mode 100644
index 000000000000..ea848f12a2c7
--- /dev/null
+++ b/src/transformers/models/levit/__init__.py
@@ -0,0 +1,75 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+
+
+_import_structure = {"configuration_levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig", "LevitOnnxConfig"]}
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_levit"] = ["LevitFeatureExtractor"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_levit"] = [
+ "LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "LevitForImageClassification",
+ "LevitForImageClassificationWithTeacher",
+ "LevitModel",
+ "LevitPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig, LevitOnnxConfig
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_levit import LevitFeatureExtractor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_levit import (
+ LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ LevitForImageClassification,
+ LevitForImageClassificationWithTeacher,
+ LevitModel,
+ LevitPreTrainedModel,
+ )
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

diff --git a/src/transformers/models/levit/configuration_levit.py b/src/transformers/models/levit/configuration_levit.py
new file mode 100644
index 000000000000..a1113d7a7512
--- /dev/null
+++ b/src/transformers/models/levit/configuration_levit.py
@@ -0,0 +1,146 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" LeViT model configuration"""
+
+from collections import OrderedDict
+from typing import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "facebook/levit-128S": "https://huggingface.co/facebook/levit-128S/resolve/main/config.json",
+ # See all LeViT models at https://huggingface.co/models?filter=levit
+}
+
+
+class LevitConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LevitModel`]. It is used to instantiate a LeViT
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the LeViT
+ [facebook/levit-base-192](https://huggingface.co/facebook/levit-base-192) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ image_size (`int`, *optional*, defaults to 224):
+ The size of the input image.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input image.
+ kernel_size (`int`, *optional*, defaults to 3):
+ The kernel size for the initial convolution layers of patch embedding.
+ stride (`int`, *optional*, defaults to 2):
+ The stride size for the initial convolution layers of patch embedding.
+ padding (`int`, *optional*, defaults to 1):
+ The padding size for the initial convolution layers of patch embedding.
+ patch_size (`int`, *optional*, defaults to 16):
+ The patch size for embeddings.
+ hidden_sizes (`List[int]`, *optional*, defaults to `[128, 256, 384]`):
+ Dimension of each of the encoder blocks.
+ num_attention_heads (`List[int]`, *optional*, defaults to `[4, 8, 12]`):
+ Number of attention heads for each attention layer in each block of the Transformer encoder.
+ depths (`List[int]`, *optional*, defaults to `[4, 4, 4]`):
+ The number of layers in each encoder block.
+ key_dim (`List[int]`, *optional*, defaults to `[16, 16, 16]`):
+ The size of key in each of the encoder blocks.
+ drop_path_rate (`int`, *optional*, defaults to 0):
+ The dropout probability for stochastic depth, used in the blocks of the Transformer encoder.
+ mlp_ratio (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
+ Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the
+ encoder blocks.
+ attention_ratio (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
+ Ratio of the size of the output dimension compared to input dimension of attention layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+ Example:
+
+ ```python
+ >>> from transformers import LevitModel, LevitConfig
+
+ >>> # Initializing a LeViT levit-base-192 style configuration
+ >>> configuration = LevitConfig()
+
+ >>> # Initializing a model from the levit-base-192 style configuration
+ >>> model = LevitModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "levit"
+
+ def __init__(
+ self,
+ image_size=224,
+ num_channels=3,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ patch_size=16,
+ hidden_sizes=[128, 256, 384],
+ num_attention_heads=[4, 8, 12],
+ depths=[4, 4, 4],
+ key_dim=[16, 16, 16],
+ drop_path_rate=0,
+ mlp_ratio=[2, 2, 2],
+ attention_ratio=[2, 2, 2],
+ initializer_range=0.02,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.image_size = image_size
+ self.num_channels = num_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.hidden_sizes = hidden_sizes
+ self.num_attention_heads = num_attention_heads
+ self.depths = depths
+ self.key_dim = key_dim
+ self.drop_path_rate = drop_path_rate
+ self.patch_size = patch_size
+ self.attention_ratio = attention_ratio
+ self.mlp_ratio = mlp_ratio
+ self.initializer_range = initializer_range
+ self.down_ops = [
+ ["Subsample", key_dim[0], hidden_sizes[0] // key_dim[0], 4, 2, 2],
+ ["Subsample", key_dim[0], hidden_sizes[1] // key_dim[0], 4, 2, 2],
+ ]
+
+
+# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
+class LevitOnnxConfig(OnnxConfig):
+
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("pixel_values", {0: "batch", 1: "sequence"}),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-4
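
A small sketch (assuming this module layout) showing how the new ONNX config can be instantiated and inspected; the actual export would go through the usual `transformers.onnx` tooling.

```python
from transformers.models.levit.configuration_levit import LevitConfig, LevitOnnxConfig

onnx_config = LevitOnnxConfig(LevitConfig())
print(onnx_config.inputs)               # OrderedDict([('pixel_values', {0: 'batch', 1: 'sequence'})])
print(onnx_config.atol_for_validation)  # 0.0001
```
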
diff --git a/src/transformers/models/levit/convert_levit_timm_to_pytorch.py b/src/transformers/models/levit/convert_levit_timm_to_pytorch.py
new file mode 100644
index 000000000000..d9449aad7ab1
--- /dev/null
+++ b/src/transformers/models/levit/convert_levit_timm_to_pytorch.py
@@ -0,0 +1,181 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert LeViT checkpoints from timm."""
+
+
+import argparse
+import json
+from collections import OrderedDict
+from functools import partial
+from pathlib import Path
+
+import torch
+
+import timm
+from huggingface_hub import hf_hub_download
+from transformers import LevitConfig, LevitFeatureExtractor, LevitForImageClassificationWithTeacher
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger()
+
+
+def convert_weight_and_push(
+ hidden_sizes: int, name: str, config: LevitConfig, save_directory: Path, push_to_hub: bool = True
+):
+ print(f"Converting {name}...")
+
+ with torch.no_grad():
+ if hidden_sizes == 128:
+ if name[-1] == "S":
+ from_model = timm.create_model("levit_128s", pretrained=True)
+ else:
+ from_model = timm.create_model("levit_128", pretrained=True)
+ if hidden_sizes == 192:
+ from_model = timm.create_model("levit_192", pretrained=True)
+ if hidden_sizes == 256:
+ from_model = timm.create_model("levit_256", pretrained=True)
+ if hidden_sizes == 384:
+ from_model = timm.create_model("levit_384", pretrained=True)
+
+ from_model.eval()
+ our_model = LevitForImageClassificationWithTeacher(config).eval()
+ huggingface_weights = OrderedDict()
+
+ weights = from_model.state_dict()
+ og_keys = list(from_model.state_dict().keys())
+ new_keys = list(our_model.state_dict().keys())
+ print(len(og_keys), len(new_keys))
+ for i in range(len(og_keys)):
+ huggingface_weights[new_keys[i]] = weights[og_keys[i]]
+ our_model.load_state_dict(huggingface_weights)
+
+ x = torch.randn((2, 3, 224, 224))
+ out1 = from_model(x)
+ out2 = our_model(x).logits
+
+ assert torch.allclose(out1, out2), "The model logits don't match the original one."
+
+ checkpoint_name = name
+ print(checkpoint_name)
+
+ if push_to_hub:
+ our_model.save_pretrained(save_directory / checkpoint_name)
+ feature_extractor = LevitFeatureExtractor()
+ feature_extractor.save_pretrained(save_directory / checkpoint_name)
+
+ print(f"Pushed {checkpoint_name}")
+
+
+def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True):
+ filename = "imagenet-1k-id2label.json"
+ num_labels = 1000
+ expected_shape = (1, num_labels)
+
+ repo_id = "datasets/huggingface/label-files"
+ num_labels = num_labels
+ id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+ id2label = {int(k): v for k, v in id2label.items()}
+
+ id2label = id2label
+ label2id = {v: k for k, v in id2label.items()}
+
+ ImageNetPreTrainedConfig = partial(LevitConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)
+
+ names_to_hidden_sizes = {
+ "levit-128S": 128,
+ "levit-128": 128,
+ "levit-192": 192,
+ "levit-256": 256,
+ "levit-384": 384,
+ }
+
+ names_to_config = {
+ "levit-128S": ImageNetPreTrainedConfig(
+ hidden_sizes=[128, 256, 384],
+ num_attention_heads=[4, 6, 8],
+ depths=[2, 3, 4],
+ key_dim=[16, 16, 16],
+ drop_path_rate=0,
+ ),
+ "levit-128": ImageNetPreTrainedConfig(
+ hidden_sizes=[128, 256, 384],
+ num_attention_heads=[4, 8, 12],
+ depths=[4, 4, 4],
+ key_dim=[16, 16, 16],
+ drop_path_rate=0,
+ ),
+ "levit-192": ImageNetPreTrainedConfig(
+ hidden_sizes=[192, 288, 384],
+ num_attention_heads=[3, 5, 6],
+ depths=[4, 4, 4],
+ key_dim=[32, 32, 32],
+ drop_path_rate=0,
+ ),
+ "levit-256": ImageNetPreTrainedConfig(
+ hidden_sizes=[256, 384, 512],
+ num_attention_heads=[4, 6, 8],
+ depths=[4, 4, 4],
+ key_dim=[32, 32, 32],
+ drop_path_rate=0,
+ ),
+ "levit-384": ImageNetPreTrainedConfig(
+ hidden_sizes=[384, 512, 768],
+ num_attention_heads=[6, 9, 12],
+ depths=[4, 4, 4],
+ key_dim=[32, 32, 32],
+ drop_path_rate=0.1,
+ ),
+ }
+
+ if model_name:
+ convert_weight_and_push(
+ names_to_hidden_sizes[model_name], model_name, names_to_config[model_name], save_directory, push_to_hub
+ )
+ else:
+ for model_name, config in names_to_config.items():
+ convert_weight_and_push(names_to_hidden_sizes[model_name], model_name, config, save_directory, push_to_hub)
+ return config, expected_shape
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--model_name",
+ default=None,
+ type=str,
+ help="The name of the model you wish to convert, it must be one of the supported Levit* architecture,",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path",
+ default="levit-dump-folder/",
+ type=Path,
+ required=False,
+ help="Path to the output PyTorch model directory.",
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ default=True,
+ type=bool,
+ required=False,
+ help="If True, push model and feature extractor to the hub.",
+ )
+
+ args = parser.parse_args()
+ pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path
+ pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True)
+ convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub)
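
A hedged usage sketch for the conversion entry point defined above; it assumes `timm` is installed and the referenced checkpoints are reachable. Note that, as written, `push_to_hub` gates whether the converted model and feature extractor are saved under `save_directory`.

```python
from pathlib import Path

from transformers.models.levit.convert_levit_timm_to_pytorch import convert_weights_and_push

# Convert a single size; "levit-128S" is one of the keys in names_to_config above.
convert_weights_and_push(Path("levit-dump-folder/"), model_name="levit-128S", push_to_hub=True)
```
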
diff --git a/src/transformers/models/levit/feature_extraction_levit.py b/src/transformers/models/levit/feature_extraction_levit.py
new file mode 100644
index 000000000000..b0ac5f6b3d30
--- /dev/null
+++ b/src/transformers/models/levit/feature_extraction_levit.py
@@ -0,0 +1,158 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for LeViT."""
+
+from typing import Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ ImageFeatureExtractionMixin,
+ ImageInput,
+ is_torch_tensor,
+)
+from ...utils import TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class LevitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
+ r"""
+ Constructs a LeViT feature extractor.
+
+ This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
+ should refer to this superclass for more information regarding those methods.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the shortest edge of the input to `int(256 / 224 * size)`.
+ size (`int` or `Tuple(int)`, *optional*, defaults to 224):
+ Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
+ integer is provided, the shorter side of the input will be resized to `size`.
+ resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`):
+ An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+ `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+ if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether or not to center crop the input to `size`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input with mean and standard deviation.
+ image_mean (`List[float]`, defaults to `[0.485, 0.456, 0.406]`):
+ The sequence of means for each channel, to be used when normalizing images.
+ image_std (`List[float]`, defaults to `[0.229, 0.224, 0.225]`):
+ The sequence of standard deviations for each channel, to be used when normalizing images.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize=True,
+ size=224,
+ resample=Image.BICUBIC,
+ do_center_crop=True,
+ do_normalize=True,
+ image_mean=IMAGENET_DEFAULT_MEAN,
+ image_std=IMAGENET_DEFAULT_STD,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+
+ def __call__(
+ self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several image(s).
+
+
+
+ NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so it is most efficient to pass
+ PIL images.
+
+
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is the
+ number of channels and H and W are the image height and width.
+
+ return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
+ width).
+ """
+ # Input type checking for clearer error
+ valid_images = False
+
+ # Check that images has a valid type
+ if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
+ valid_images = True
+ elif isinstance(images, (list, tuple)):
+ if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
+ valid_images = True
+
+ if not valid_images:
+ raise ValueError(
+ "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
+ "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
+ )
+
+ is_batched = bool(
+ isinstance(images, (list, tuple))
+ and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
+ )
+
+ if not is_batched:
+ images = [images]
+
+ # transformations (resizing + center cropping + normalization)
+ if self.do_resize and self.size is not None:
+ size_ = int((256 / 224) * self.size)
+ images = [
+ self.resize(image=image, size=size_, resample=self.resample, default_to_square=False)
+ for image in images
+ ]
+ if self.do_center_crop:
+ images = [self.center_crop(image=image, size=self.size) for image in images]
+ if self.do_normalize:
+ images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
+
+ # return as BatchFeature
+ data = {"pixel_values": images}
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ return encoded_inputs
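
Usage sketch for the feature extractor defined above; the image is a random placeholder rather than a real photo, and `return_tensors="pt"` assumes PyTorch is installed.

```python
import numpy as np
from PIL import Image

from transformers.models.levit.feature_extraction_levit import LevitFeatureExtractor

image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))
feature_extractor = LevitFeatureExtractor()
inputs = feature_extractor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
```
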
diff --git a/src/transformers/models/levit/modeling_levit.py b/src/transformers/models/levit/modeling_levit.py
new file mode 100644
index 000000000000..581edf7d7c6c
--- /dev/null
+++ b/src/transformers/models/levit/modeling_levit.py
@@ -0,0 +1,744 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LeViT model."""
+
+import itertools
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...modeling_outputs import (
+ BaseModelOutputWithNoAttention,
+ BaseModelOutputWithPoolingAndNoAttention,
+ ImageClassifierOutputWithNoAttention,
+ ModelOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_levit import LevitConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "LevitConfig"
+_FEAT_EXTRACTOR_FOR_DOC = "LevitFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/levit-128S"
+_EXPECTED_OUTPUT_SHAPE = [1, 16, 384]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "facebook/levit-128S"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "facebook/levit-128S",
+ # See all LeViT models at https://huggingface.co/models?filter=levit
+]
+
+
+@dataclass
+class LevitForImageClassificationWithTeacherOutput(ModelOutput):
+ """
+ Output type of [`LevitForImageClassificationWithTeacher`].
+
+ Args:
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores as the average of the `cls_logits` and `distillation_logits`.
+ cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
+ class token).
+ distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
+ distillation token).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
+ plus the initial embedding outputs.
+ """
+
+ logits: torch.FloatTensor = None
+ cls_logits: torch.FloatTensor = None
+ distillation_logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class LevitConvEmbeddings(nn.Module):
+ """
+ LeViT Conv Embeddings with Batch Norm, used in the initial patch embedding layer.
+ """
+
+ def __init__(
+ self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bn_weight_init=1
+ ):
+ super().__init__()
+ self.convolution = nn.Conv2d(
+ in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=False
+ )
+ self.batch_norm = nn.BatchNorm2d(out_channels)
+
+ def forward(self, embeddings):
+ embeddings = self.convolution(embeddings)
+ embeddings = self.batch_norm(embeddings)
+ return embeddings
+
+
+class LevitPatchEmbeddings(nn.Module):
+ """
+ LeViT patch embeddings, for final embeddings to be passed to transformer blocks. It consists of multiple
+ `LevitConvEmbeddings`.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.embedding_layer_1 = LevitConvEmbeddings(
+ config.num_channels, config.hidden_sizes[0] // 8, config.kernel_size, config.stride, config.padding
+ )
+ self.activation_layer_1 = nn.Hardswish()
+
+ self.embedding_layer_2 = LevitConvEmbeddings(
+ config.hidden_sizes[0] // 8, config.hidden_sizes[0] // 4, config.kernel_size, config.stride, config.padding
+ )
+ self.activation_layer_2 = nn.Hardswish()
+
+ self.embedding_layer_3 = LevitConvEmbeddings(
+ config.hidden_sizes[0] // 4, config.hidden_sizes[0] // 2, config.kernel_size, config.stride, config.padding
+ )
+ self.activation_layer_3 = nn.Hardswish()
+
+ self.embedding_layer_4 = LevitConvEmbeddings(
+ config.hidden_sizes[0] // 2, config.hidden_sizes[0], config.kernel_size, config.stride, config.padding
+ )
+ self.num_channels = config.num_channels
+
+ def forward(self, pixel_values):
+ num_channels = pixel_values.shape[1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ embeddings = self.embedding_layer_1(pixel_values)
+ embeddings = self.activation_layer_1(embeddings)
+ embeddings = self.embedding_layer_2(embeddings)
+ embeddings = self.activation_layer_2(embeddings)
+ embeddings = self.embedding_layer_3(embeddings)
+ embeddings = self.activation_layer_3(embeddings)
+ embeddings = self.embedding_layer_4(embeddings)
+ return embeddings.flatten(2).transpose(1, 2)
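
For orientation, the four stride-2 convolutions above reduce a 224x224 input to a 14x14 grid before `flatten(2).transpose(1, 2)` turns it into a token sequence. A standalone shape check under the default config values (batch norm and Hardswish are omitted since they do not change shapes):

```python
import torch
from torch import nn

convs = nn.Sequential(
    nn.Conv2d(3, 16, 3, 2, 1, bias=False),    # 224 -> 112, channels = hidden_sizes[0] // 8
    nn.Conv2d(16, 32, 3, 2, 1, bias=False),   # 112 -> 56,  hidden_sizes[0] // 4
    nn.Conv2d(32, 64, 3, 2, 1, bias=False),   # 56  -> 28,  hidden_sizes[0] // 2
    nn.Conv2d(64, 128, 3, 2, 1, bias=False),  # 28  -> 14,  hidden_sizes[0]
)
out = convs(torch.randn(1, 3, 224, 224))
print(out.shape)                             # torch.Size([1, 128, 14, 14])
print(out.flatten(2).transpose(1, 2).shape)  # torch.Size([1, 196, 128]); 14 * 14 tokens
```
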
+
+
+class MLPLayerWithBN(nn.Module):
+ def __init__(self, input_dim, output_dim, bn_weight_init=1):
+ super().__init__()
+ self.linear = nn.Linear(in_features=input_dim, out_features=output_dim, bias=False)
+ self.batch_norm = nn.BatchNorm1d(output_dim)
+
+ def forward(self, hidden_state):
+ hidden_state = self.linear(hidden_state)
+ hidden_state = self.batch_norm(hidden_state.flatten(0, 1)).reshape_as(hidden_state)
+ return hidden_state
+
+
+class LevitSubsample(nn.Module):
+ def __init__(self, stride, resolution):
+ super().__init__()
+ self.stride = stride
+ self.resolution = resolution
+
+ def forward(self, hidden_state):
+ batch_size, _, channels = hidden_state.shape
+ hidden_state = hidden_state.view(batch_size, self.resolution, self.resolution, channels)[
+ :, :: self.stride, :: self.stride
+ ].reshape(batch_size, -1, channels)
+ return hidden_state
+
+
+class LevitAttention(nn.Module):
+ def __init__(self, hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.scale = key_dim**-0.5
+ self.key_dim = key_dim
+ self.attention_ratio = attention_ratio
+ self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads * 2
+ self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
+
+ self.queries_keys_values = MLPLayerWithBN(hidden_sizes, self.out_dim_keys_values)
+ self.activation = nn.Hardswish()
+ self.projection = MLPLayerWithBN(self.out_dim_projection, hidden_sizes, bn_weight_init=0)
+
+ points = list(itertools.product(range(resolution), range(resolution)))
+ len_points = len(points)
+ attention_offsets, indices = {}, []
+ for p1 in points:
+ for p2 in points:
+ offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ indices.append(attention_offsets[offset])
+
+ self.attention_bias_cache = {}
+ self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
+ self.register_buffer("attention_bias_idxs", torch.LongTensor(indices).view(len_points, len_points))
+
+ @torch.no_grad()
+ def train(self, mode=True):
+ super().train(mode)
+ if mode and self.attention_bias_cache:
+ self.attention_bias_cache = {} # clear ab cache
+
+ def get_attention_biases(self, device):
+ if self.training:
+ return self.attention_biases[:, self.attention_bias_idxs]
+ else:
+ device_key = str(device)
+ if device_key not in self.attention_bias_cache:
+ self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+ return self.attention_bias_cache[device_key]
+
+ def forward(self, hidden_state):
+ batch_size, seq_length, _ = hidden_state.shape
+ queries_keys_values = self.queries_keys_values(hidden_state)
+ query, key, value = queries_keys_values.view(batch_size, seq_length, self.num_attention_heads, -1).split(
+ [self.key_dim, self.key_dim, self.attention_ratio * self.key_dim], dim=3
+ )
+ query = query.permute(0, 2, 1, 3)
+ key = key.permute(0, 2, 1, 3)
+ value = value.permute(0, 2, 1, 3)
+
+ attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
+ attention = attention.softmax(dim=-1)
+ hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, seq_length, self.out_dim_projection)
+ hidden_state = self.projection(self.activation(hidden_state))
+ return hidden_state
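
The attention-bias bookkeeping above indexes a learned table by the absolute `(dx, dy)` offset between grid positions. Enumerating it for a tiny 2x2 grid (not part of the patch) makes the sizes concrete:

```python
import itertools

resolution = 2
points = list(itertools.product(range(resolution), range(resolution)))

attention_offsets, indices = {}, []
for p1 in points:
    for p2 in points:
        offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
        if offset not in attention_offsets:
            attention_offsets[offset] = len(attention_offsets)
        indices.append(attention_offsets[offset])

print(attention_offsets)          # {(0, 0): 0, (0, 1): 1, (1, 0): 2, (1, 1): 3}
print(len(points), len(indices))  # 4 positions, 16 pairwise entries in the index matrix
# so attention_biases has one learned value per head per distinct offset (4 here),
# and attention_bias_idxs gathers them into a (4, 4) bias map.
```
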
+
+
+class LevitAttentionSubsample(nn.Module):
+ def __init__(
+ self,
+ input_dim,
+ output_dim,
+ key_dim,
+ num_attention_heads,
+ attention_ratio,
+ stride,
+ resolution_in,
+ resolution_out,
+ ):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.scale = key_dim**-0.5
+ self.key_dim = key_dim
+ self.attention_ratio = attention_ratio
+ self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads
+ self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
+ self.resolution_out = resolution_out
+ # resolution_in is the initial resolution, resolution_out is the final resolution after downsampling
+ self.keys_values = MLPLayerWithBN(input_dim, self.out_dim_keys_values)
+ self.queries_subsample = LevitSubsample(stride, resolution_in)
+ self.queries = MLPLayerWithBN(input_dim, key_dim * num_attention_heads)
+ self.activation = nn.Hardswish()
+ self.projection = MLPLayerWithBN(self.out_dim_projection, output_dim)
+
+ self.attention_bias_cache = {}
+
+ points = list(itertools.product(range(resolution_in), range(resolution_in)))
+ points_ = list(itertools.product(range(resolution_out), range(resolution_out)))
+ len_points, len_points_ = len(points), len(points_)
+ attention_offsets, indices = {}, []
+ for p1 in points_:
+ for p2 in points:
+ size = 1
+ offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), abs(p1[1] * stride - p2[1] + (size - 1) / 2))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ indices.append(attention_offsets[offset])
+
+ self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
+ self.register_buffer("attention_bias_idxs", torch.LongTensor(indices).view(len_points_, len_points))
+
+ @torch.no_grad()
+ def train(self, mode=True):
+ super().train(mode)
+ if mode and self.attention_bias_cache:
+ self.attention_bias_cache = {} # clear ab cache
+
+ def get_attention_biases(self, device):
+ if self.training:
+ return self.attention_biases[:, self.attention_bias_idxs]
+ else:
+ device_key = str(device)
+ if device_key not in self.attention_bias_cache:
+ self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+ return self.attention_bias_cache[device_key]
+
+ def forward(self, hidden_state):
+ batch_size, seq_length, _ = hidden_state.shape
+ key, value = (
+ self.keys_values(hidden_state)
+ .view(batch_size, seq_length, self.num_attention_heads, -1)
+ .split([self.key_dim, self.attention_ratio * self.key_dim], dim=3)
+ )
+ key = key.permute(0, 2, 1, 3)
+ value = value.permute(0, 2, 1, 3)
+
+ query = self.queries(self.queries_subsample(hidden_state))
+ query = query.view(batch_size, self.resolution_out**2, self.num_attention_heads, self.key_dim).permute(
+ 0, 2, 1, 3
+ )
+
+ attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
+ attention = attention.softmax(dim=-1)
+ hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, -1, self.out_dim_projection)
+ hidden_state = self.projection(self.activation(hidden_state))
+ return hidden_state
+
+
+class LevitMLPLayer(nn.Module):
+ """
+ MLP Layer with `2X` expansion in contrast to ViT with `4X`.
+ """
+
+ def __init__(self, input_dim, hidden_dim):
+ super().__init__()
+ self.linear_up = MLPLayerWithBN(input_dim, hidden_dim)
+ self.activation = nn.Hardswish()
+ self.linear_down = MLPLayerWithBN(hidden_dim, input_dim)
+
+ def forward(self, hidden_state):
+ hidden_state = self.linear_up(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ hidden_state = self.linear_down(hidden_state)
+ return hidden_state
+
+
+class LevitResidualLayer(nn.Module):
+ """
+ Residual Block for LeViT
+ """
+
+ def __init__(self, module, drop_rate):
+ super().__init__()
+ self.module = module
+ self.drop_rate = drop_rate
+
+ def forward(self, hidden_state):
+ if self.training and self.drop_rate > 0:
+ rnd = torch.rand(hidden_state.size(0), 1, 1, device=hidden_state.device)
+ rnd = rnd.ge_(self.drop_rate).div(1 - self.drop_rate).detach()
+ hidden_state = hidden_state + self.module(hidden_state) * rnd
+ return hidden_state
+ else:
+ hidden_state = hidden_state + self.module(hidden_state)
+ return hidden_state
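
`LevitResidualLayer` implements per-example drop path: the residual branch is dropped with probability `drop_rate`, and surviving branches are rescaled by `1 / (1 - drop_rate)` so the expected contribution matches evaluation mode. A quick numerical check of that rescaling:

```python
import torch

torch.manual_seed(0)
drop_rate = 0.2
branch = torch.ones(100_000, 1, 1)  # pretend the module output is 1 everywhere

rnd = torch.rand(branch.size(0), 1, 1)
keep = rnd.ge_(drop_rate).div(1 - drop_rate)  # 0 with prob 0.2, 1.25 with prob 0.8
print(float((branch * keep).mean()))          # ~1.0, matching the un-dropped branch
```
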
+
+
+class LevitStage(nn.Module):
+ """
+ LeViT Stage consisting of `LevitMLPLayer` and `LevitAttention` layers.
+ """
+
+ def __init__(
+ self,
+ config,
+ idx,
+ hidden_sizes,
+ key_dim,
+ depths,
+ num_attention_heads,
+ attention_ratio,
+ mlp_ratio,
+ down_ops,
+ resolution_in,
+ ):
+ super().__init__()
+ self.layers = []
+ self.config = config
+ self.resolution_in = resolution_in
+ # resolution_in is the initial resolution, resolution_out is the final resolution after downsampling
+ for _ in range(depths):
+ self.layers.append(
+ LevitResidualLayer(
+ LevitAttention(hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution_in),
+ self.config.drop_path_rate,
+ )
+ )
+ if mlp_ratio > 0:
+ hidden_dim = hidden_sizes * mlp_ratio
+ self.layers.append(
+ LevitResidualLayer(LevitMLPLayer(hidden_sizes, hidden_dim), self.config.drop_path_rate)
+ )
+
+ if down_ops[0] == "Subsample":
+ self.resolution_out = (self.resolution_in - 1) // down_ops[5] + 1
+ self.layers.append(
+ LevitAttentionSubsample(
+ *self.config.hidden_sizes[idx : idx + 2],
+ key_dim=down_ops[1],
+ num_attention_heads=down_ops[2],
+ attention_ratio=down_ops[3],
+ stride=down_ops[5],
+ resolution_in=resolution_in,
+ resolution_out=self.resolution_out,
+ )
+ )
+ self.resolution_in = self.resolution_out
+ if down_ops[4] > 0:
+ hidden_dim = self.config.hidden_sizes[idx + 1] * down_ops[4]
+ self.layers.append(
+ LevitResidualLayer(
+ LevitMLPLayer(self.config.hidden_sizes[idx + 1], hidden_dim), self.config.drop_path_rate
+ )
+ )
+
+ self.layers = nn.ModuleList(self.layers)
+
+ def get_resolution(self):
+ return self.resolution_in
+
+ def forward(self, hidden_state):
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
+
+
+class LevitEncoder(nn.Module):
+ """
+ LeViT Encoder consisting of multiple `LevitStage` stages.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ resolution = self.config.image_size // self.config.patch_size
+ self.stages = []
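+ # append a sentinel entry so the final stage is built without a subsampling block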
+ self.config.down_ops.append([""])
+
+ for stage_idx in range(len(config.depths)):
+ stage = LevitStage(
+ config,
+ stage_idx,
+ config.hidden_sizes[stage_idx],
+ config.key_dim[stage_idx],
+ config.depths[stage_idx],
+ config.num_attention_heads[stage_idx],
+ config.attention_ratio[stage_idx],
+ config.mlp_ratio[stage_idx],
+ config.down_ops[stage_idx],
+ resolution,
+ )
+ resolution = stage.get_resolution()
+ self.stages.append(stage)
+
+ self.stages = nn.ModuleList(self.stages)
+
+ def forward(self, hidden_state, output_hidden_states=False, return_dict=True):
+ all_hidden_states = () if output_hidden_states else None
+
+ for stage in self.stages:
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_state,)
+ hidden_state = stage(hidden_state)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_state,)
+ if not return_dict:
+ return tuple(v for v in [hidden_state, all_hidden_states] if v is not None)
+
+ return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states)
+
+
+class LevitClassificationLayer(nn.Module):
+ """
+ LeViT Classification Layer
+ """
+
+ def __init__(self, input_dim, output_dim):
+ super().__init__()
+ self.batch_norm = nn.BatchNorm1d(input_dim)
+ self.linear = nn.Linear(input_dim, output_dim)
+
+ def forward(self, hidden_state):
+ hidden_state = self.batch_norm(hidden_state)
+ logits = self.linear(hidden_state)
+ return logits
+
+
+class LevitPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = LevitConfig
+ base_model_prefix = "levit"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, LevitModel):
+ module.gradient_checkpointing = value
+
+
+LEVIT_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`LevitConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+LEVIT_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
+ [`AutoFeatureExtractor.__call__`] for details.
+
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Levit model outputting raw features without any specific head on top.",
+ LEVIT_START_DOCSTRING,
+)
+class LevitModel(LevitPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+ self.patch_embeddings = LevitPatchEmbeddings(config)
+ self.encoder = LevitEncoder(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPoolingAndNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ embeddings = self.patch_embeddings(pixel_values)
+ encoder_outputs = self.encoder(
+ embeddings,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+
+ # global average pooling, (batch_size, seq_length, hidden_sizes) -> (batch_size, hidden_sizes)
+ pooled_output = last_hidden_state.mean(dim=1)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ )
+
+
+@add_start_docstrings(
+ """
+ Levit Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """,
+ LEVIT_START_DOCSTRING,
+)
+class LevitForImageClassification(LevitPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+ self.num_labels = config.num_labels
+ self.levit = LevitModel(config)
+
+ # Classifier head
+ self.classifier = (
+ LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
+ if config.num_labels > 0
+ else torch.nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=ImageClassifierOutputWithNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
+
+ sequence_output = outputs[0]
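+ # global average pooling over the token dimension before the classification head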
+ sequence_output = sequence_output.mean(1)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutputWithNoAttention(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ )
+
+
+@add_start_docstrings(
+ """
+ LeViT Model transformer with image classification heads on top (a linear layer on top of the final hidden state and
+ a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.
+
+ .. warning::
+ This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
+ supported.
+ """,
+ LEVIT_START_DOCSTRING,
+)
+class LevitForImageClassificationWithTeacher(LevitPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+ self.num_labels = config.num_labels
+ self.levit = LevitModel(config)
+
+ # Classifier head
+ self.classifier = (
+ LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
+ if config.num_labels > 0
+ else torch.nn.Identity()
+ )
+ self.classifier_distill = (
+ LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
+ if config.num_labels > 0
+ else torch.nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=LevitForImageClassificationWithTeacherOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
+
+ sequence_output = outputs[0]
+ sequence_output = sequence_output.mean(1)
+ cls_logits, distill_logits = self.classifier(sequence_output), self.classifier_distill(sequence_output)
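+ # the final prediction is the average of the classification head and the distillation head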
+ logits = (cls_logits + distill_logits) / 2
+
+ if not return_dict:
+ output = (logits, cls_logits, distill_logits) + outputs[2:]
+ return output
+
+ return LevitForImageClassificationWithTeacherOutput(
+ logits=logits,
+ cls_logits=cls_logits,
+ distillation_logits=distill_logits,
+ hidden_states=outputs.hidden_states,
+ )
diff --git a/src/transformers/models/longformer/__init__.py b/src/transformers/models/longformer/__init__.py
index 329b8f1cdf92..1705703b5ac3 100644
--- a/src/transformers/models/longformer/__init__.py
+++ b/src/transformers/models/longformer/__init__.py
@@ -18,7 +18,13 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -30,10 +36,20 @@
"tokenization_longformer": ["LongformerTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_longformer_fast"] = ["LongformerTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_longformer"] = [
"LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"LongformerForMaskedLM",
@@ -46,7 +62,12 @@
"LongformerSelfAttention",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_longformer"] = [
"TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFLongformerForMaskedLM",
@@ -68,10 +89,20 @@
)
from .tokenization_longformer import LongformerTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_longformer_fast import LongformerTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
LongformerForMaskedLM,
@@ -84,7 +115,12 @@
LongformerSelfAttention,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_longformer import (
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLongformerForMaskedLM,
diff --git a/src/transformers/models/longformer/configuration_longformer.py b/src/transformers/models/longformer/configuration_longformer.py
index 2c9fd17b35ec..53ceeafb64ba 100644
--- a/src/transformers/models/longformer/configuration_longformer.py
+++ b/src/transformers/models/longformer/configuration_longformer.py
@@ -24,9 +24,15 @@
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/config.json",
"allenai/longformer-large-4096": "https://huggingface.co/allenai/longformer-large-4096/resolve/main/config.json",
- "allenai/longformer-large-4096-finetuned-triviaqa": "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/config.json",
- "allenai/longformer-base-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/config.json",
- "allenai/longformer-large-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/config.json",
+ "allenai/longformer-large-4096-finetuned-triviaqa": (
+ "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/config.json"
+ ),
+ "allenai/longformer-base-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/config.json"
+ ),
+ "allenai/longformer-large-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py
index 20a2e9d239e2..7661f90bfbb4 100755
--- a/src/transformers/models/longformer/modeling_longformer.py
+++ b/src/transformers/models/longformer/modeling_longformer.py
@@ -388,9 +388,10 @@ def _get_question_end_index(input_ids, sep_token_id):
batch_size = input_ids.shape[0]
assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
- assert (
- sep_token_indices.shape[0] == 3 * batch_size
- ), f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this error."
+ assert sep_token_indices.shape[0] == 3 * batch_size, (
+ f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You"
+ " might also consider to set `global_attention_mask` manually in the forward function to avoid this error."
+ )
return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
@@ -446,8 +447,6 @@ def __init__(self, config):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- # position_ids (1, len position emb) is contiguous in memory and exported when serialized
- self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.padding_idx = config.pad_token_id
@@ -468,13 +467,8 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs
else:
input_shape = inputs_embeds.size()[:-1]
- seq_length = input_shape[1]
-
- if position_ids is None:
- position_ids = self.position_ids[:, :seq_length]
-
if token_type_ids is None:
- token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
@@ -585,7 +579,7 @@ def forward(
# cast to fp32/fp16 then replace 1's with -inf
float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
- remove_from_windowed_attention_mask, -10000.0
+ remove_from_windowed_attention_mask, torch.finfo(query_vectors.dtype).min
)
# diagonal mask with zeros everywhere and -inf inplace of padding
diagonal_mask = self._sliding_chunks_query_key_matmul(
@@ -600,7 +594,10 @@ def forward(
seq_len,
self.num_heads,
self.one_sided_attn_window_size * 2 + 1,
- ], f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
+ ], (
+ f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
+ f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
+ )
# compute local attention probs from global attention keys and contact over window dim
if is_global_attn:
@@ -954,7 +951,7 @@ def _concat_with_global_key_attn_probs(
attn_probs_from_global_key[
is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1]
- ] = -10000.0
+ ] = torch.finfo(attn_probs_from_global_key.dtype).min
return attn_probs_from_global_key
@@ -1040,17 +1037,21 @@ def _compute_global_attn_output_from_hidden(
batch_size * self.num_heads,
max_num_global_attn_indices,
seq_len,
- ], f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {global_attn_scores.size()}."
+ ], (
+ "global_attn_scores have the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
+ f" {global_attn_scores.size()}."
+ )
global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
global_attn_scores[
is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
- ] = -10000.0
+ ] = torch.finfo(global_attn_scores.dtype).min
global_attn_scores = global_attn_scores.masked_fill(
is_index_masked[:, None, None, :],
- -10000.0,
+ torch.finfo(global_attn_scores.dtype).min,
)
global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
@@ -1083,7 +1084,11 @@ def _compute_global_attn_output_from_hidden(
batch_size * self.num_heads,
max_num_global_attn_indices,
self.head_dim,
- ], f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attn_output.size()}."
+ ], (
+ "global_attn_output tensor has the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
+ f" {global_attn_output.size()}."
+ )
global_attn_probs = global_attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
global_attn_output = global_attn_output.view(
@@ -1380,7 +1385,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
config_class = LongformerConfig
base_model_prefix = "longformer"
supports_gradient_checkpointing = True
- _keys_to_ignore_on_load_missing = [r"position_ids"]
+ _keys_to_ignore_on_load_unexpected = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
@@ -1770,23 +1775,31 @@ def forward(
Returns:
- Examples:
+ Mask filling example:
```python
- >>> import torch
- >>> from transformers import LongformerForMaskedLM, LongformerTokenizer
+ >>> from transformers import LongformerTokenizer, LongformerForMaskedLM
- >>> model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
>>> tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
+ >>> model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
+ ```
- >>> SAMPLE_TEXT = " ".join(["Hello world! "] * 1000) # long input document
- >>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1
+ Let's try a very long input.
- >>> attention_mask = None # default is local attention everywhere, which is a good choice for MaskedLM
- >>> # check `LongformerModel.forward` for more details how to set *attention_mask*
- >>> outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
- >>> loss = outputs.loss
- >>> prediction_logits = outputs.logits
+ ```python
+ >>> TXT = (
+ ... "My friends are but they eat too many carbs."
+ ... + " That's why I decide not to eat with them." * 300
+ ... )
+ >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
+ >>> logits = model(input_ids).logits
+
+ >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+ >>> probs = logits[0, masked_index].softmax(dim=0)
+ >>> values, predictions = probs.topk(5)
+
+ >>> tokenizer.decode(predictions).split()
+ ['healthy', 'skinny', 'thin', 'good', 'vegetarian']
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1848,9 +1861,11 @@ def __init__(self, config):
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="jpwahle/longformer-base-plagiarism-detection",
output_type=LongformerSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output="'ORIGINAL'",
+ expected_loss=5.44,
)
def forward(
self,
@@ -2032,7 +2047,8 @@ def forward(
if global_attention_mask is None:
if input_ids is None:
logger.warning(
- "It is not possible to automatically generate the `global_attention_mask` because input_ids is None. Please make sure that it is correctly set."
+ "It is not possible to automatically generate the `global_attention_mask` because input_ids is"
+ " None. Please make sure that it is correctly set."
)
else:
# set global attention on question tokens automatically
@@ -2114,9 +2130,14 @@ def __init__(self, config):
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="brad1141/Longformer-finetuned-norm",
output_type=LongformerTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output=(
+ "['Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence',"
+ " 'Evidence', 'Evidence', 'Evidence', 'Evidence']"
+ ),
+ expected_loss=0.63,
)
def forward(
self,
diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py
index 124fe2c06fec..0dfd9c66617f 100644
--- a/src/transformers/models/longformer/modeling_tf_longformer.py
+++ b/src/transformers/models/longformer/modeling_tf_longformer.py
@@ -775,7 +775,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_scores),
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
- message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}",
+ message=(
+ f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
+ f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
+ ),
)
# compute global attn indices required through out forward fn
@@ -828,7 +831,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
@@ -921,7 +927,10 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
tf.debugging.assert_equal(
shape_list(query),
shape_list(key),
- message=f"Shape of query and key should be equal, but got query: {shape_list(query)} and key: {shape_list(key)}",
+ message=(
+ f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
+ f" {shape_list(key)}"
+ ),
)
chunks_count = seq_len // window_overlap - 1
@@ -1206,7 +1215,10 @@ def _chunk(hidden_states, window_overlap):
tf.debugging.assert_equal(
shape_list(chunked_hidden_states),
[batch_size, num_output_chunks, frame_size],
- message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.",
+ message=(
+ "Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
+ f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
+ ),
)
chunked_hidden_states = tf.reshape(
@@ -1384,7 +1396,11 @@ def _compute_global_attn_output_from_hidden(
tf.debugging.assert_equal(
shape_list(global_attn_scores),
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
- message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {shape_list(global_attn_scores)}.",
+ message=(
+ "global_attn_scores have the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
+ f" {shape_list(global_attn_scores)}."
+ ),
)
global_attn_scores = tf.reshape(
@@ -1423,7 +1439,10 @@ def _compute_global_attn_output_from_hidden(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
@@ -1442,7 +1461,11 @@ def _compute_global_attn_output_from_hidden(
tf.debugging.assert_equal(
shape_list(global_attn_output),
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
- message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {shape_list(global_attn_output)}.",
+ message=(
+ "global_attn_output tensor has the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
+ f" {shape_list(global_attn_output)}."
+ ),
)
global_attn_output = tf.reshape(
@@ -2079,10 +2102,12 @@ def get_prefix_bias_name(self):
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="allenai/longformer-base-4096",
output_type=TFLongformerMaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
mask="",
+ expected_output="' Paris'",
+ expected_loss=0.44,
)
def call(
self,
@@ -2175,6 +2200,8 @@ def __init__(self, config, *inputs, **kwargs):
checkpoint="allenai/longformer-large-4096-finetuned-triviaqa",
output_type=TFLongformerQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output="' puppet'",
+ expected_loss=0.96,
)
def call(
self,
@@ -2207,7 +2234,10 @@ def call(
if global_attention_mask is None and input_ids is not None:
if shape_list(tf.where(input_ids == self.config.sep_token_id))[0] != 3 * shape_list(input_ids)[0]:
logger.warning(
- f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error. The global attention is disabled for this forward pass."
+ f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for"
+ " questions answering. You might also consider to set `global_attention_mask` manually in the"
+ " forward function to avoid this. This is most likely an error. The global attention is disabled"
+ " for this forward pass."
)
global_attention_mask = tf.fill(shape_list(input_ids), value=0)
else:
@@ -2318,9 +2348,11 @@ def __init__(self, config, *inputs, **kwargs):
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="hf-internal-testing/tiny-random-longformer",
output_type=TFLongformerSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output="'LABEL_1'",
+ expected_loss=0.69,
)
def call(
self,
@@ -2556,9 +2588,15 @@ def __init__(self, config, *inputs, **kwargs):
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="hf-internal-testing/tiny-random-longformer",
output_type=TFLongformerTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output=(
+ "['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1',"
+ " 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1',"
+ " 'LABEL_1', 'LABEL_1']"
+ ),
+ expected_loss=0.59,
)
def call(
self,
diff --git a/src/transformers/models/longformer/tokenization_longformer.py b/src/transformers/models/longformer/tokenization_longformer.py
index 19445622b821..b594580647a2 100644
--- a/src/transformers/models/longformer/tokenization_longformer.py
+++ b/src/transformers/models/longformer/tokenization_longformer.py
@@ -25,17 +25,33 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/vocab.json",
- "allenai/longformer-large-4096": "https://huggingface.co/allenai/longformer-large-4096/resolve/main/vocab.json",
- "allenai/longformer-large-4096-finetuned-triviaqa": "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/vocab.json",
- "allenai/longformer-base-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/vocab.json",
- "allenai/longformer-large-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/vocab.json",
+ "allenai/longformer-large-4096": (
+ "https://huggingface.co/allenai/longformer-large-4096/resolve/main/vocab.json"
+ ),
+ "allenai/longformer-large-4096-finetuned-triviaqa": (
+ "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/vocab.json"
+ ),
+ "allenai/longformer-base-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/vocab.json"
+ ),
+ "allenai/longformer-large-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/vocab.json"
+ ),
},
"merges_file": {
"allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/merges.txt",
- "allenai/longformer-large-4096": "https://huggingface.co/allenai/longformer-large-4096/resolve/main/merges.txt",
- "allenai/longformer-large-4096-finetuned-triviaqa": "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/merges.txt",
- "allenai/longformer-base-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/merges.txt",
- "allenai/longformer-large-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/merges.txt",
+ "allenai/longformer-large-4096": (
+ "https://huggingface.co/allenai/longformer-large-4096/resolve/main/merges.txt"
+ ),
+ "allenai/longformer-large-4096-finetuned-triviaqa": (
+ "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/merges.txt"
+ ),
+ "allenai/longformer-base-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/merges.txt"
+ ),
+ "allenai/longformer-large-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/merges.txt"
+ ),
},
}
diff --git a/src/transformers/models/longformer/tokenization_longformer_fast.py b/src/transformers/models/longformer/tokenization_longformer_fast.py
index a7d06b1fc3db..45a888397117 100644
--- a/src/transformers/models/longformer/tokenization_longformer_fast.py
+++ b/src/transformers/models/longformer/tokenization_longformer_fast.py
@@ -26,24 +26,50 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/vocab.json",
- "allenai/longformer-large-4096": "https://huggingface.co/allenai/longformer-large-4096/resolve/main/vocab.json",
- "allenai/longformer-large-4096-finetuned-triviaqa": "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/vocab.json",
- "allenai/longformer-base-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/vocab.json",
- "allenai/longformer-large-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/vocab.json",
+ "allenai/longformer-large-4096": (
+ "https://huggingface.co/allenai/longformer-large-4096/resolve/main/vocab.json"
+ ),
+ "allenai/longformer-large-4096-finetuned-triviaqa": (
+ "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/vocab.json"
+ ),
+ "allenai/longformer-base-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/vocab.json"
+ ),
+ "allenai/longformer-large-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/vocab.json"
+ ),
},
"merges_file": {
"allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/merges.txt",
- "allenai/longformer-large-4096": "https://huggingface.co/allenai/longformer-large-4096/resolve/main/merges.txt",
- "allenai/longformer-large-4096-finetuned-triviaqa": "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/merges.txt",
- "allenai/longformer-base-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/merges.txt",
- "allenai/longformer-large-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/merges.txt",
+ "allenai/longformer-large-4096": (
+ "https://huggingface.co/allenai/longformer-large-4096/resolve/main/merges.txt"
+ ),
+ "allenai/longformer-large-4096-finetuned-triviaqa": (
+ "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/merges.txt"
+ ),
+ "allenai/longformer-base-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/merges.txt"
+ ),
+ "allenai/longformer-large-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/merges.txt"
+ ),
},
"tokenizer_file": {
- "allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/tokenizer.json",
- "allenai/longformer-large-4096": "https://huggingface.co/allenai/longformer-large-4096/resolve/main/tokenizer.json",
- "allenai/longformer-large-4096-finetuned-triviaqa": "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/tokenizer.json",
- "allenai/longformer-base-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/tokenizer.json",
- "allenai/longformer-large-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/tokenizer.json",
+ "allenai/longformer-base-4096": (
+ "https://huggingface.co/allenai/longformer-base-4096/resolve/main/tokenizer.json"
+ ),
+ "allenai/longformer-large-4096": (
+ "https://huggingface.co/allenai/longformer-large-4096/resolve/main/tokenizer.json"
+ ),
+ "allenai/longformer-large-4096-finetuned-triviaqa": (
+ "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/tokenizer.json"
+ ),
+ "allenai/longformer-base-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/tokenizer.json"
+ ),
+ "allenai/longformer-large-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/longt5/__init__.py b/src/transformers/models/longt5/__init__.py
new file mode 100644
index 000000000000..fd355f6d5a93
--- /dev/null
+++ b/src/transformers/models/longt5/__init__.py
@@ -0,0 +1,88 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available
+
+
+_import_structure = {
+ "configuration_longt5": ["LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongT5Config", "LongT5OnnxConfig"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_longt5"] = [
+ "LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "LongT5EncoderModel",
+ "LongT5ForConditionalGeneration",
+ "LongT5Model",
+ "LongT5PreTrainedModel",
+ ]
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_longt5"] = [
+ "FlaxLongT5ForConditionalGeneration",
+ "FlaxLongT5Model",
+ "FlaxLongT5PreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_longt5 import LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP, LongT5Config, LongT5OnnxConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_longt5 import (
+ LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST,
+ LongT5EncoderModel,
+ LongT5ForConditionalGeneration,
+ LongT5Model,
+ LongT5PreTrainedModel,
+ )
+
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_longt5 import (
+ FlaxLongT5ForConditionalGeneration,
+ FlaxLongT5Model,
+ FlaxLongT5PreTrainedModel,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/longt5/configuration_longt5.py b/src/transformers/models/longt5/configuration_longt5.py
new file mode 100644
index 000000000000..705fdc493958
--- /dev/null
+++ b/src/transformers/models/longt5/configuration_longt5.py
@@ -0,0 +1,178 @@
+# coding=utf-8
+# Copyright 2022, The LongT5 Authors and HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" LongT5 model configuration"""
+from typing import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxSeq2SeqConfigWithPast
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "google/long-t5-local-base": "https://huggingface.co/google/long-t5-local-base/blob/main/config.json",
+ "google/long-t5-local-large": "https://huggingface.co/google/long-t5-local-large/blob/main/config.json",
+ "google/long-t5-tglobal-base": "https://huggingface.co/google/long-t5-tglobal-base/blob/main/config.json",
+ "google/long-t5-tglobal-large": "https://huggingface.co/google/long-t5-tglobal-large/blob/main/config.json",
+}
+
+
+class LongT5Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LongT5Model`] or a [`FlaxLongT5Model`]. It is
+ used to instantiate a LongT5 model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the LongT5
+ [google/long-t5-local-base](https://huggingface.co/google/long-t5-local-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Arguments:
+ vocab_size (`int`, *optional*, defaults to 32128):
+ Vocabulary size of the LongT5 model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`LongT5Model`].
+ d_model (`int`, *optional*, defaults to 512):
+ Size of the encoder layers and the pooler layer.
+ d_kv (`int`, *optional*, defaults to 64):
+ Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model //
+ num_heads`.
+ d_ff (`int`, *optional*, defaults to 2048):
+ Size of the intermediate feed forward layer in each `LongT5Block`.
+ num_layers (`int`, *optional*, defaults to 6):
+ Number of hidden layers in the Transformer encoder.
+ num_decoder_layers (`int`, *optional*):
+ Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
+ num_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ local_radius (`int`, *optional*, defaults to 127):
+ Number of tokens to the left/right of each token to locally self-attend to in the local attention mechanism.
+ global_block_size (`int`, *optional*, defaults to 16):
+ Length of blocks an input sequence is divided into for a global token representation. Used only for
+ `encoder_attention_type = "transient-global"`.
+ relative_attention_num_buckets (`int`, *optional*, defaults to 32):
+ The number of buckets to use for each attention layer.
+ relative_attention_max_distance (`int`, *optional*, defaults to 128):
+ The maximum distance of the longer sequences for the bucket separation.
+ dropout_rate (`float`, *optional*, defaults to 0.1):
+ The ratio for all dropout layers.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
+ The epsilon used by the layer normalization layers.
+ initializer_factor (`float`, *optional*, defaults to 1):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+ feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
+ Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. The original LongT5
+ implementation uses the `"gated-gelu"` feed forward projection.
+ encoder_attention_type (`string`, *optional*, defaults to `"local"`):
+ Type of encoder attention to be used. Should be one of `"local"` or `"transient-global"`, which are
+ supported by LongT5 implementation.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
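+
+ Example (a minimal usage sketch of the configuration and model classes added here):
+
+ ```python
+ >>> from transformers import LongT5Config, LongT5Model
+
+ >>> # Initializing a default (google/long-t5-local-base style) configuration
+ >>> configuration = LongT5Config()
+
+ >>> # Initializing a model with random weights from that configuration
+ >>> model = LongT5Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```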
+ """
+ model_type = "longt5"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+
+ def __init__(
+ self,
+ vocab_size=32128,
+ d_model=512,
+ d_kv=64,
+ d_ff=2048,
+ num_layers=6,
+ num_decoder_layers=None,
+ num_heads=8,
+ local_radius=127,
+ global_block_size=16,
+ relative_attention_num_buckets=32,
+ relative_attention_max_distance=128,
+ dropout_rate=0.1,
+ layer_norm_epsilon=1e-6,
+ initializer_factor=1.0,
+ feed_forward_proj="relu",
+ is_encoder_decoder=True,
+ encoder_attention_type="local",
+ use_cache=True,
+ pad_token_id=0,
+ eos_token_id=1,
+ **kwargs
+ ):
+
+ self.vocab_size = vocab_size
+ self.d_model = d_model
+ self.d_kv = d_kv
+ self.d_ff = d_ff
+ self.num_layers = num_layers
+ # default = symmetry
+ self.num_decoder_layers = num_decoder_layers if num_decoder_layers is not None else self.num_layers
+ self.num_heads = num_heads
+ self.local_radius = local_radius
+ self.global_block_size = global_block_size
+ self.relative_attention_num_buckets = relative_attention_num_buckets
+ self.relative_attention_max_distance = relative_attention_max_distance
+ self.dropout_rate = dropout_rate
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_factor = initializer_factor
+ self.feed_forward_proj = feed_forward_proj
+ self.encoder_attention_type = encoder_attention_type
+ self.use_cache = use_cache
+
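+ # `feed_forward_proj` is either "<act_fn>" or "gated-<act_fn>"; split it into the activation
+ # name and a flag for the gated variant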
+ act_info = self.feed_forward_proj.split("-")
+ self.dense_act_fn = act_info[-1]
+ self.is_gated_act = act_info[0] == "gated"
+
+ if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
+ raise ValueError(
+ f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
+ "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+ "'gated-gelu' or 'relu'"
+ )
+
+ # for backwards compatibility
+ if feed_forward_proj == "gated-gelu":
+ self.dense_act_fn = "gelu_new"
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ **kwargs,
+ )
+
+
+class LongT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ common_inputs = {
+ "input_ids": {0: "batch", 1: "encoder_sequence"},
+ "attention_mask": {0: "batch", 1: "encoder_sequence"},
+ }
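+ # with past key/values the decoder is fed only the newly generated token, so only the batch
+ # axis of `decoder_input_ids` stays dynamic and the attention masks cover past + current positions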
+ if self.use_past:
+ common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
+ common_inputs["decoder_input_ids"] = {0: "batch"}
+ common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+ else:
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+ common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+ if self.use_past:
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
+
+ return common_inputs
+
+ @property
+ def default_onnx_opset(self) -> int:
+ return 13
diff --git a/src/transformers/models/longt5/convert_longt5x_checkpoint_to_flax.py b/src/transformers/models/longt5/convert_longt5x_checkpoint_to_flax.py
new file mode 100644
index 000000000000..41cc3a2005dd
--- /dev/null
+++ b/src/transformers/models/longt5/convert_longt5x_checkpoint_to_flax.py
@@ -0,0 +1,214 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Convert T5/LongT5X checkpoints from the original repository to JAX/FLAX model. This script is an extension of
+'src/transformers/models/t5/convert_t5x_checkpoint_to_flax.py'.
+"""
+
+import argparse
+
+from t5x import checkpoints
+from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM
+
+
+def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
+ config = AutoConfig.from_pretrained(config_name)
+ flax_model = FlaxAutoModelForSeq2SeqLM.from_config(config=config)
+ t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
+
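+ # gated feed-forward checkpoints store two input projections ("wi_0"/"wi_1") instead of a single "wi"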
+ split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"]
+
+ if config.model_type == "t5":
+ encoder_attn_name = "SelfAttention"
+ if config.model_type == "longt5" and config.encoder_attention_type == "local":
+ encoder_attn_name = "LocalSelfAttention"
+ elif config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
+ encoder_attn_name = "TransientGlobalSelfAttention"
+ else:
+ raise ValueError(
+ "Given config is expected to have `model_type='t5'`, or `model_type='longt5` with `encoder_attention_type`"
+ " attribute with a value from ['local', 'transient-global]."
+ )
+
+ # Encoder
+ for layer_index in range(config.num_layers):
+ layer_name = f"layers_{str(layer_index)}"
+
+ # Self-Attention
+ t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
+ t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
+ t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
+ t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
+
+ # Global input layer norm
+ if config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
+ t5x_global_layer_norm = t5x_model["target"]["encoder"][layer_name]["attention"]["T5LayerNorm_0"]["scale"]
+
+ # Layer Normalization
+ t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
+
+ if split_mlp_wi:
+ t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
+ t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
+ else:
+ t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
+
+ t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
+
+ # Layer Normalization
+ t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
+
+ # Assigning
+ flax_model_encoder_layer_block = flax_model.params["encoder"]["block"][str(layer_index)]["layer"]
+ flax_model_encoder_layer_block["0"][encoder_attn_name]["k"]["kernel"] = t5x_attention_key
+ flax_model_encoder_layer_block["0"][encoder_attn_name]["o"]["kernel"] = t5x_attention_out
+ flax_model_encoder_layer_block["0"][encoder_attn_name]["q"]["kernel"] = t5x_attention_query
+ flax_model_encoder_layer_block["0"][encoder_attn_name]["v"]["kernel"] = t5x_attention_value
+
+ flax_model_encoder_layer_block["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm
+
+ # Global input layer norm
+ if config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
+ flax_model_encoder_layer_block["0"][encoder_attn_name]["global_input_layer_norm"][
+ "weight"
+ ] = t5x_global_layer_norm
+
+ if split_mlp_wi:
+ flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
+ flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
+ else:
+ flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
+
+ flax_model_encoder_layer_block["1"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
+ flax_model_encoder_layer_block["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm
+
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"] = flax_model_encoder_layer_block
+
+ # Only for layer 0:
+ t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T
+ flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["relative_attention_bias"][
+ "embedding"
+ ] = t5x_encoder_rel_embedding
+
+ # Side/global relative position_bias + layer norm
+ if config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
+ t5x_encoder_global_rel_embedding = t5x_model["target"]["encoder"]["side_relpos_bias"]["rel_embedding"].T
+ flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["global_relative_attention_bias"][
+ "embedding"
+ ] = t5x_encoder_global_rel_embedding
+
+ # Assigning
+ t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
+ flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
+
+ # Decoder
+ for layer_index in range(config.num_layers):
+ layer_name = f"layers_{str(layer_index)}"
+
+ # Self-Attention
+ t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"]
+ t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"]
+ t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"]
+ t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"]
+
+ # Layer Normalization
+ t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"][
+ "scale"
+ ]
+
+ # Encoder-Decoder-Attention
+ t5x_enc_dec_attention_module = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]
+ t5x_enc_dec_attention_key = t5x_enc_dec_attention_module["key"]["kernel"]
+ t5x_enc_dec_attention_out = t5x_enc_dec_attention_module["out"]["kernel"]
+ t5x_enc_dec_attention_query = t5x_enc_dec_attention_module["query"]["kernel"]
+ t5x_enc_dec_attention_value = t5x_enc_dec_attention_module["value"]["kernel"]
+
+ # Layer Normalization
+ t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"]
+
+ # MLP
+ if split_mlp_wi:
+ t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"]
+ t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"]
+ else:
+ t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"]
+
+ t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"]
+
+ # Layer Normalization
+ tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
+
+ # Assigning
+ flax_model_decoder_layer_block = flax_model.params["decoder"]["block"][str(layer_index)]["layer"]
+ flax_model_decoder_layer_block["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
+ flax_model_decoder_layer_block["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
+ flax_model_decoder_layer_block["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
+ flax_model_decoder_layer_block["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
+
+ flax_model_decoder_layer_block["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm
+
+ flax_model_decoder_layer_block["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key
+ flax_model_decoder_layer_block["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out
+ flax_model_decoder_layer_block["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query
+ flax_model_decoder_layer_block["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value
+
+ flax_model_decoder_layer_block["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm
+
+ if split_mlp_wi:
+ flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
+ flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
+ else:
+ flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
+
+ flax_model_decoder_layer_block["2"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
+
+ flax_model_decoder_layer_block["2"]["layer_norm"]["weight"] = tx5_mlp_layer_norm
+
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"] = flax_model_decoder_layer_block
+
+ # Decoder Normalization
+ t5x_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
+ flax_model.params["decoder"]["final_layer_norm"]["weight"] = t5x_decoder_norm
+
+ # Only for layer 0:
+ t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T
+ flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][
+ "embedding"
+ ] = t5x_decoder_rel_embedding
+
+ # Token Embeddings
+ t5x_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
+ flax_model.params["shared"]["embedding"] = t5x_token_embeddings
+
+ # LM Head (only in v1.1 and LongT5 checkpoints)
+ if "logits_dense" in t5x_model["target"]["decoder"]:
+ flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]
+
+ flax_model.save_pretrained(flax_dump_folder_path)
+ print("T5X Model was sucessfully converted!")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the T5X checkpoint."
+ )
+ parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of LongT5/T5 model.")
+ parser.add_argument(
+ "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
+ )
+ args = parser.parse_args()
+ convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
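+ # Example invocation (script name and paths below are placeholders):
+ #   python convert_longt5x_checkpoint_to_flax.py \
+ #     --t5x_checkpoint_path /path/to/t5x_checkpoint \
+ #     --config_name google/long-t5-local-base \
+ #     --flax_dump_folder_path /path/to/flax_dump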
diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py
new file mode 100644
index 000000000000..766dc36888e2
--- /dev/null
+++ b/src/transformers/models/longt5/modeling_flax_longt5.py
@@ -0,0 +1,2402 @@
+# coding=utf-8
+# Copyright 2022 LongT5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Flax LongT5 model."""
+
+
+import copy
+from typing import Any, Callable, List, Optional, Tuple
+
+import numpy as np
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax.random import PRNGKey
+
+from ...modeling_flax_outputs import (
+ FlaxBaseModelOutput,
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
+ FlaxCausalLMOutputWithCrossAttentions,
+ FlaxSeq2SeqLMOutput,
+ FlaxSeq2SeqModelOutput,
+)
+from ...modeling_flax_utils import (
+ ACT2FN,
+ FlaxPreTrainedModel,
+ append_call_sample_docstring,
+ append_replace_return_docstrings,
+ overwrite_call_docstring,
+)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from .configuration_longt5 import LongT5Config
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/long-t5-local-base"
+_CONFIG_FOR_DOC = "LongT5Config"
+_TOKENIZER_FOR_DOC = "T5Tokenizer"
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
+def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = np.zeros_like(input_ids)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
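+ # For illustration (hypothetical values): input_ids [[5, 6, 1]] with pad_token_id=0 and
+ # decoder_start_token_id=0 becomes [[0, 5, 6]]; -100 label positions are replaced by the pad token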
+ return shifted_input_ids
+
+
+def _pad_to_multiple(x: jnp.ndarray, block_len: int, axis: int, pad_value: int = 0) -> jnp.ndarray:
+ """Pad an array so that a sequence length will be a multiple of `block_len`"""
+ pad_len = -x.shape[axis] % block_len
+ pad = [(0, 0)] * x.ndim
+ pad[axis] = (0, pad_len)
+ x = jnp.pad(x, pad_width=pad, mode="constant", constant_values=pad_value)
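+ # e.g. (hypothetical shapes) an axis of length 10 with block_len=4 gets pad_len = -10 % 4 = 2,
+ # i.e. it is padded up to length 12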
+ return x
+
+
+def _split_into_blocks(x: jnp.ndarray, block_len: int, axis: int) -> jnp.ndarray:
+ """Split an input array into blocks of a given `block_len` along the given `axis`. If the dimension length
+ is not a multiple of `block_len`, it will be padded first with selected `pad_value`.
+ """
+ # pad tensor to multiple of block_len
+ if x.shape[axis] % block_len != 0:
+ x = _pad_to_multiple(x, block_len, axis, pad_value=0)
+ num_blocks = x.shape[axis] // block_len
+ output_shape = x.shape[:axis] + (num_blocks, block_len) + x.shape[(axis + 1) :]
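+ # e.g. (hypothetical shapes) (batch, 10, d) with block_len=4 and axis=1 is padded to (batch, 12, d)
+ # and then reshaped to (batch, 3, 4, d)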
+ return x.reshape(output_shape)
+
+
+def _concatenate_3_blocks(x: jnp.ndarray, block_axis: int, sequence_axis: int, pad_value: int = 0) -> jnp.ndarray:
+ """Concatenate three consecutive blocks for each input block for local attentiont.
+ For more information, see: https://arxiv.org/pdf/2112.07916.pdf.
+ """
+ num_blocks = x.shape[block_axis]
+
+ pad = [(0, 0)] * x.ndim
+ pad[block_axis] = (1, 1)
+ # [batch_size, num_blocks, block_len] -> [batch_size, num_blocks + 2, block_len]
+ x = jnp.pad(x, pad_width=pad, mode="constant", constant_values=pad_value)
+
+ blocks_list: List[np.array] = []
+ for i in range(3):
+ # We use indexing approach here:
+ # https://numpy.org/doc/stable/user/basics.indexing.html#dealing-with-variable-numbers-of-indices-within-programs
+ indices = [slice(0, None)] * x.ndim
+ indices[block_axis] = slice(i, i + num_blocks)
+ indices = tuple(indices)
+ blocks_list.append(x[indices])
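+ # output block i is thus flanked by blocks i - 1 and i + 1 (zero-padded at the sequence edges)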
+ return jnp.concatenate(blocks_list, axis=sequence_axis) # [batch_size, num_blocks, 3 * block_len, ...]
+
+
+def _make_3block_relative_position_ids(block_len: int) -> jnp.ndarray:
+ """Makes 3-blocked relative position ids for local attention."""
+ position_ids = jnp.arange(3 * block_len, dtype=jnp.int32)
+ center_position_ids = position_ids[block_len:-block_len]
+ relative_position_ids = position_ids[None, :] - center_position_ids[:, None] # [block_len, 3 * block_len]
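+ # e.g. for block_len=2 (hypothetical) this is a (2, 6) matrix with entries ranging from -3 to 3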
+ return relative_position_ids
+
+
+def _mask_local_attention_mask(local_attention_mask: np.ndarray, block_len: int) -> jnp.ndarray:
+ """Mask local attention mask to enforce that tokens are not allowed to attend tokens farther than ``local_radius."""
+ relative_position_ids = _make_3block_relative_position_ids(block_len)
+ locality_mask = jnp.abs(relative_position_ids) < block_len
+ locality_mask = locality_mask[None, None, :, :]
+ return jnp.logical_and(local_attention_mask, locality_mask)
+
+
+def _get_local_attention_mask(attention_mask: np.ndarray, block_len: int) -> jnp.ndarray:
+ """Prepare attention mask to be applied for a local attention."""
+ # [batch_size, num_blocks, block_len]
+ _blocked_attention_mask = _split_into_blocks(attention_mask, block_len, axis=1)
+ # [batch_size, num_block, 3 * block_len]
+ _3blocked_attention_mask = _concatenate_3_blocks(_blocked_attention_mask, block_axis=1, sequence_axis=2)
+
+ _blocked_attention_mask = _blocked_attention_mask[..., None]
+ _3blocked_attention_mask = _3blocked_attention_mask[..., None, :]
+ # [batch_size, num_block, block_len, 3 * block_len]
+ local_attention_mask = jnp.logical_and(_blocked_attention_mask, _3blocked_attention_mask)
+ local_attention_mask = _mask_local_attention_mask(local_attention_mask, block_len)
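+ # each query token may only attend to key tokens of its 3-block neighbourhood that lie strictly
+ # within block_len positions of it (i.e. within the local radius)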
+ # [batch_size, 1, num_block, block_len, 3 * block_len]
+ return local_attention_mask[:, None, ...]
+
+
+def _make_global_fixed_block_ids(attention_mask: np.ndarray, global_block_size: int) -> Tuple[jnp.ndarray, np.ndarray]:
+ """Obtain the "fixed block" global id corresponding to each input token.
+
+ This implementation is a simplified version of the original Flaxformer implementation adapted from:
+ https://github.com/google/flaxformer/blob/main/flaxformer/architectures/longt5/long_attention.py.
+
+ In our scenario, as we use this strategy only for a decoder, orphan tokens, i.e. those tokens which do not make up
+ a whole fixed block, are assigned to the preceding block.
+
+ Padding tokens from the original sequence are represented by -1.
+ """
+ batch_size, seq_len = attention_mask.shape[:2]
+
+ def handle_orphan_tokens(block_ids: np.ndarray) -> jnp.ndarray:
+ block_ends = (jnp.arange(seq_len) % global_block_size) == global_block_size - 1
+ true_block_ends = jnp.logical_and(block_ends, block_ids >= 0)
+ full_blocks = true_block_ends.sum(-1)[..., None]
+ block_ids = jnp.minimum(block_ids, full_blocks - 1)
+ return block_ids
+
+ fixed_block_mask = jnp.ones_like(attention_mask) / global_block_size
+ fixed_block_mask = jnp.cumsum(fixed_block_mask, axis=1) - fixed_block_mask
+ mask = jnp.where(attention_mask != 0.0, 1.0, -1000.0)
+ global_block_ids = jnp.maximum(
+ jnp.floor(mask + fixed_block_mask - 1.0), jnp.array(-1.0, dtype=attention_mask.dtype)
+ )
+ # set padding tokens to -1
+ global_block_ids = (global_block_ids * attention_mask) + (attention_mask - 1)
+ # [batch_size, seq_len]
+ global_block_ids = handle_orphan_tokens(global_block_ids)
+ num_globals = seq_len // global_block_size
+
+ # [batch_size, seq_len // global_block_size]
+ if num_globals > 0:
+ _sequence_block_ids_max = jnp.repeat(global_block_ids.max(axis=-1)[:, None], repeats=num_globals, axis=1)
+ else:
+ _sequence_block_ids_max = jnp.zeros((batch_size, 0), dtype=global_block_ids.dtype)
+ global_segment_ids = jnp.cumsum(jnp.ones((batch_size, num_globals)), axis=-1) - 1
+ global_segment_ids = jnp.where(global_segment_ids <= _sequence_block_ids_max, 1, 0)
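+ # e.g. (hypothetical) an all-ones mask of length 6 with global_block_size=2 yields
+ # global_block_ids [0, 0, 1, 1, 2, 2] and global_segment_ids [1, 1, 1]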
+ return global_block_ids, global_segment_ids
+
+
+def _make_side_relative_position_ids(attention_mask: np.ndarray, global_block_size: int) -> np.ndarray:
+ """Create the relative position tensor for local -> global attention."""
+ block_ids, global_segment_ids = _make_global_fixed_block_ids(attention_mask, global_block_size)
+ global_seq_len = global_segment_ids.shape[-1]
+ global_positions = jnp.arange(global_seq_len)
+ side_relative_position = global_positions - block_ids[..., None]
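+ # signed distance from each token's block to every global block position (negative for global
+ # blocks to the left of the token's own block)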
+ return side_relative_position
+
+
+def _create_global_aggregates(hidden_states: np.ndarray, block_ids: np.ndarray, global_seq_len: int) -> np.ndarray:
+ """Compute individual block aggregates by summing over individual blocks."""
+ # (batch..., seq_len, global_seq_len))
+ one_hot_block_ids = jax.nn.one_hot(block_ids, global_seq_len)
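+ # padding tokens carry block_id -1, which one_hot maps to an all-zero row, so they do not
+ # contribute to any global aggregate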
+ return jnp.einsum("...nd,...ng->...gd", hidden_states, one_hot_block_ids)
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerNorm with T5->LongT5
+class FlaxLongT5LayerNorm(nn.Module):
+ hidden_size: int
+ dtype: jnp.dtype = jnp.float32
+ eps: float = 1e-6
+ weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
+
+ def setup(self):
+ self.weight = self.param("weight", self.weight_init, (self.hidden_size,))
+
+ def __call__(self, hidden_states):
+ """
+ Construct a layernorm module in the LongT5 style; No bias and no subtraction of mean.
+ """
+ # layer norm should always be calculated in float32
+ variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
+ hidden_states = hidden_states / jnp.sqrt(variance + self.eps)
+
+ return self.weight * hidden_states
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseActDense with T5->LongT5
+class FlaxLongT5DenseActDense(nn.Module):
+ config: LongT5Config
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
+ wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
+
+ self.wi = nn.Dense(
+ self.config.d_ff,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(wi_init_std),
+ dtype=self.dtype,
+ )
+ self.wo = nn.Dense(
+ self.config.d_model,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(wo_init_std),
+ dtype=self.dtype,
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+ self.act = ACT2FN[self.config.dense_act_fn]
+
+ def __call__(self, hidden_states, deterministic=True):
+ hidden_states = self.wi(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+ hidden_states = self.wo(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseGatedActDense with T5->LongT5
+class FlaxLongT5DenseGatedActDense(nn.Module):
+ config: LongT5Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
+ wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
+
+ self.wi_0 = nn.Dense(
+ self.config.d_ff,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(wi_init_std),
+ dtype=self.dtype,
+ )
+ self.wi_1 = nn.Dense(
+ self.config.d_ff,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(wi_init_std),
+ dtype=self.dtype,
+ )
+ self.wo = nn.Dense(
+ self.config.d_model,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(wo_init_std),
+ dtype=self.dtype,
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+ self.act = ACT2FN[self.config.dense_act_fn]
+
+ def __call__(self, hidden_states, deterministic):
+ hidden_gelu = self.act(self.wi_0(hidden_states))
+ hidden_linear = self.wi_1(hidden_states)
+ hidden_states = hidden_gelu * hidden_linear
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+ hidden_states = self.wo(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerFF with T5->LongT5
+class FlaxLongT5LayerFF(nn.Module):
+ config: LongT5Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ if self.config.is_gated_act:
+ self.DenseReluDense = FlaxLongT5DenseGatedActDense(self.config, dtype=self.dtype)
+ else:
+ self.DenseReluDense = FlaxLongT5DenseActDense(self.config, dtype=self.dtype)
+
+ self.layer_norm = FlaxLongT5LayerNorm(
+ self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+
+ def __call__(self, hidden_states, deterministic=True):
+ forwarded_states = self.layer_norm(hidden_states)
+ forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic)
+ hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic)
+ return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention with T5->LongT5
+class FlaxLongT5Attention(nn.Module):
+ config: LongT5Config
+ has_relative_attention_bias: bool = False
+ causal: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
+ self.relative_attention_max_distance = self.config.relative_attention_max_distance
+ self.d_model = self.config.d_model
+ self.key_value_proj_dim = self.config.d_kv
+ self.n_heads = self.config.num_heads
+ self.dropout = self.config.dropout_rate
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+ q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
+ kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
+ o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
+
+ self.q = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(q_init_std),
+ dtype=self.dtype,
+ )
+ self.k = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(kv_init_std),
+ dtype=self.dtype,
+ )
+ self.v = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(kv_init_std),
+ dtype=self.dtype,
+ )
+ self.o = nn.Dense(
+ self.d_model,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(o_init_std),
+ dtype=self.dtype,
+ )
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embed(
+ self.relative_attention_num_buckets,
+ self.n_heads,
+ embedding_init=jax.nn.initializers.normal(kv_init_std),
+ )
+
+ @staticmethod
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
+ """
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0) * num_buckets
+ relative_position = jnp.abs(relative_position)
+ else:
+ relative_position = -jnp.clip(relative_position, a_max=0)
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_position_if_large = max_exact + (
+ jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
+ )
+ relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
+
+ relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
+
+ return relative_buckets.astype("i4")
+
+ def compute_bias(self, query_length, key_length):
+ """Compute binned relative position bias"""
+ context_position = jnp.arange(query_length, dtype="i4")[:, None]
+ memory_position = jnp.arange(key_length, dtype="i4")[None, :]
+
+ relative_position = memory_position - context_position
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position,
+ bidirectional=(not self.causal),
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+
+ values = self.relative_attention_bias(relative_position_bucket)
+ values = values.transpose((2, 0, 1))[None, :, :, :]
+ return values
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))
+
+ @nn.compact
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slightly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = jax.lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = jax.lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions
+ # that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
+ def _create_position_bias(
+ self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift
+ ):
+ cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache)
+ key_length = key_states.shape[1]
+ query_length = key_length if cache_is_filled else query_states.shape[1]
+
+ if self.has_relative_attention_bias:
+ position_bias = self.compute_bias(query_length, key_length)
+ elif attention_mask is not None:
+ position_bias = jnp.zeros_like(attention_mask)
+ else:
+ position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype)
+
+ # if key and values are already calculated, only the last query position bias should be taken
+ if cache_is_filled:
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ position_bias = jax.lax.dynamic_slice(
+ position_bias,
+ (0, 0, causal_attention_mask_shift, 0),
+ (1, self.n_heads, seq_length, max_decoder_length),
+ )
+ return position_bias
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ key_value_states=None,
+ position_bias=None,
+ use_cache=False,
+ output_attentions=False,
+ deterministic=True,
+ init_cache=False,
+ ):
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ # q, k, v projections
+ query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head)
+ key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states)
+ value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states)
+
+ # reshape to (batch_size, seq_length, n_heads, head_dim)
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # counter-act scaling in dot_product_attention_weights function
+ query_states *= jnp.sqrt(query_states.shape[-1])
+
+ # for fast decoding causal attention mask should be shifted
+ causal_attention_mask_shift = (
+ self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0
+ )
+ # create causal attention_mask; attention_mask has to be defined when model is causal
+ if self.causal:
+ causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")
+
+ # fast decoding for generate requires special attention_mask
+ if self.has_variable("cache", "cached_key"):
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_attention_mask = jax.lax.dynamic_slice(
+ causal_attention_mask,
+ (0, 0, causal_attention_mask_shift, 0),
+ (1, 1, seq_length, max_decoder_length),
+ )
+
+ # broadcast causal attention mask & attention mask to fit for merge
+ causal_attention_mask = jnp.broadcast_to(
+ causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:]
+ )
+ attention_mask = jnp.broadcast_to(
+ jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape
+ )
+ attention_mask = combine_masks(attention_mask, causal_attention_mask)
+ elif attention_mask is not None:
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
+ key_states, value_states, query_states, attention_mask
+ )
+
+ # replace masked positions with the minimum value of the computation dtype
+ if attention_mask is not None:
+ mask_value = jnp.finfo(self.dtype).min
+ attention_mask = jax.lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
+ )
+
+ if position_bias is None:
+ # compute position bias (only for first layer)
+ position_bias = self._create_position_bias(
+ key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift
+ )
+
+ if attention_mask is not None:
+ position_bias = position_bias + attention_mask
+
+ # create dropout rng
+ dropout_rng = None
+ if not deterministic and self.dropout > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ # Softmax(QK^T)
+ attn_weights = dot_product_attention_weights(
+ query_states,
+ key_states,
+ bias=position_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.dropout,
+ broadcast_dropout=True,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ )
+
+ # multiply with value states
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+
+ # bring back to (batch_size, seq_length, d_model)
+ attn_output = self._merge_heads(attn_output)
+
+ # apply output matrix
+ attn_output = self.o(attn_output)
+
+ outputs = (attn_output, position_bias)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+
+ return outputs
+
+
+class FlaxLongT5LocalAttention(nn.Module):
+ config: LongT5Config
+ has_relative_attention_bias: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
+ self.relative_attention_max_distance = self.config.relative_attention_max_distance
+ self.d_model = self.config.d_model
+ self.key_value_proj_dim = self.config.d_kv
+ self.n_heads = self.config.num_heads
+ self.local_radius = self.config.local_radius
+ self.block_len = self.local_radius + 1
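+ # with the 3-block attention pattern used below, block_len = local_radius + 1 gives every token a
+ # receptive field of local_radius tokens on each side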
+ self.dropout = self.config.dropout_rate
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+ q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
+ kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
+ o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
+
+ self.q = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(q_init_std),
+ dtype=self.dtype,
+ )
+ self.k = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(kv_init_std),
+ dtype=self.dtype,
+ )
+ self.v = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(kv_init_std),
+ dtype=self.dtype,
+ )
+ self.o = nn.Dense(
+ self.d_model,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(o_init_std),
+ dtype=self.dtype,
+ )
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embed(
+ self.relative_attention_num_buckets,
+ self.n_heads,
+ embedding_init=jax.nn.initializers.normal(kv_init_std),
+ )
+
+ @staticmethod
+ # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention._relative_position_bucket
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
+ """
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0) * num_buckets
+ relative_position = jnp.abs(relative_position)
+ else:
+ relative_position = -jnp.clip(relative_position, a_max=0)
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_position_if_large = max_exact + (
+ jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
+ )
+ relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
+
+ relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
+
+ return relative_buckets.astype("i4")
+
+ def compute_bias(self, block_length: int):
+ """Compute binned relative position bias"""
+ memory_position = jnp.arange(3 * block_length, dtype="i4")
+ context_position = memory_position[block_length:-block_length]
+
+ relative_position = memory_position[None, :] - context_position[:, None]
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position,
+ bidirectional=True,
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+
+ values = self.relative_attention_bias(relative_position_bucket)
+ values = values.transpose((2, 0, 1))[None, None, :, :, :]
+ return values
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[0], -1, self.inner_dim)
+
+ def _create_position_bias(self, block_len: int, attention_mask: Optional[np.ndarray]) -> np.ndarray:
+ # position_bias shape: (1, 1, n_heads, block_len, 3 * block_len)
+ if self.has_relative_attention_bias:
+ position_bias = self.compute_bias(block_len)
+ elif attention_mask is not None:
+ position_bias = jnp.zeros_like(attention_mask)
+ else:
+ position_bias = jnp.zeros((1, 1, self.n_heads, block_len, 3 * block_len), dtype=self.dtype)
+
+ return position_bias
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ key_value_states=None,
+ position_bias=None,
+ output_attentions=False,
+ deterministic=True,
+ ):
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ # q, k, v projections
+ query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head)
+ key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states)
+ value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states)
+
+ # reshape to (batch_size, seq_length, n_heads, head_dim)
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, head_dim)
+ query_states = _split_into_blocks(query_states, self.block_len, axis=1)
+ key_states = _split_into_blocks(key_states, self.block_len, axis=1)
+ value_states = _split_into_blocks(value_states, self.block_len, axis=1)
+
+ # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
+ key_states = _concatenate_3_blocks(key_states, block_axis=1, sequence_axis=2)
+ value_states = _concatenate_3_blocks(value_states, block_axis=1, sequence_axis=2)
+
+ # counter-act scaling in dot_product_attention_weights function
+ query_states *= jnp.sqrt(query_states.shape[-1])
+
+ if attention_mask is not None:
+ attention_mask = _get_local_attention_mask(attention_mask, self.block_len)
+
+ # replace masked positions with -1e10
+ attention_mask = jax.lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+ )
+
+ if position_bias is None:
+ # compute position bias (only for first layer)
+ position_bias = self._create_position_bias(self.block_len, attention_mask)
+
+ if attention_mask is not None:
+ position_bias = position_bias + attention_mask.swapaxes(1, 2)
+
+ # create dropout rng
+ dropout_rng = None
+ if not deterministic and self.dropout > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ # Softmax(QK^T)
+ attn_weights = dot_product_attention_weights(
+ query_states,
+ key_states,
+ bias=position_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.dropout,
+ broadcast_dropout=True,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ )
+
+ # multiply with value states
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+
+ # bring back to (batch_size, seq_length, d_model)
+ attn_output = self._merge_heads(attn_output)
+ attn_output = attn_output[:, :seq_length, :]
+
+ # apply output matrix
+ attn_output = self.o(attn_output)
+
+ outputs = (attn_output, position_bias)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+
+ return outputs
+
+
+class FlaxLongT5TransientGlobalAttention(nn.Module):
+ config: LongT5Config
+ has_relative_attention_bias: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
+ self.relative_attention_max_distance = self.config.relative_attention_max_distance
+ self.d_model = self.config.d_model
+ self.key_value_proj_dim = self.config.d_kv
+ self.n_heads = self.config.num_heads
+ self.local_radius = self.config.local_radius
+ self.block_len = self.local_radius + 1
+ self.global_block_size = self.config.global_block_size
+ self.dropout = self.config.dropout_rate
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+ q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
+ kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
+ o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
+
+ self.q = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(q_init_std),
+ dtype=self.dtype,
+ )
+ self.k = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(kv_init_std),
+ dtype=self.dtype,
+ )
+ self.v = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(kv_init_std),
+ dtype=self.dtype,
+ )
+ self.o = nn.Dense(
+ self.d_model,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(o_init_std),
+ dtype=self.dtype,
+ )
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embed(
+ self.relative_attention_num_buckets,
+ self.n_heads,
+ embedding_init=jax.nn.initializers.normal(kv_init_std),
+ )
+
+ # Relative attention bias & layer norm for global attention
+ if self.has_relative_attention_bias:
+ self.global_relative_attention_bias = nn.Embed(
+ self.relative_attention_num_buckets,
+ self.n_heads,
+ embedding_init=jax.nn.initializers.normal(kv_init_std),
+ )
+ self.global_input_layer_norm = FlaxLongT5LayerNorm(
+ self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
+ )
+
+ @staticmethod
+ # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention._relative_position_bucket
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
+ """
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0) * num_buckets
+ relative_position = jnp.abs(relative_position)
+ else:
+ relative_position = -jnp.clip(relative_position, a_max=0)
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_position_if_large = max_exact + (
+ jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
+ )
+ relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
+
+ relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
+
+ return relative_buckets.astype("i4")
+
+ def compute_bias(self, block_length: int):
+ """Compute binned relative position bias"""
+ memory_position = jnp.arange(3 * block_length, dtype="i4")
+ context_position = memory_position[block_length:-block_length]
+
+ relative_position = memory_position[None, :] - context_position[:, None]
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position,
+ bidirectional=True,
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+
+ values = self.relative_attention_bias(relative_position_bucket)
+ values = values.transpose((2, 0, 1))[None, None, :, :, :]
+ return values
+
+ def compute_side_bias(self, attention_mask: np.ndarray, global_segment_ids: np.ndarray) -> np.ndarray:
+ # (batch_size, 1, seq_len, global_seq_len)
+ side_attention_mask = jnp.equal(attention_mask[..., None], global_segment_ids[:, None, :])[:, None, ...]
+ attention_side_bias = jax.lax.select(
+ side_attention_mask > 0,
+ jnp.full(side_attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(side_attention_mask.shape, -1e10).astype(self.dtype),
+ )
+ # (batch_size, seq_len, global_seq_len)
+ side_relative_position = _make_side_relative_position_ids(attention_mask, self.global_block_size)
+ side_relative_position_bucket = self._relative_position_bucket(
+ side_relative_position,
+ bidirectional=True,
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+ # (batch_size, seq_len, global_seq_len, num_heads)
+ side_bias = self.global_relative_attention_bias(side_relative_position_bucket)
+
+ # (batch_size, num_heads, seq_len, global_seq_len)
+ side_bias = jnp.transpose(side_bias, (0, 3, 1, 2))
+ # (batch_size, num_heads, seq_len, global_seq_len)
+ attention_side_bias = attention_side_bias + side_bias
+ return attention_side_bias
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[0], -1, self.inner_dim)
+
+ def _create_position_bias(self, block_len: int, attention_mask: Optional[np.ndarray]) -> np.ndarray:
+ # position_bias shape: (1, 1, n_heads, block_len, 3 * block_len)
+ if self.has_relative_attention_bias:
+ position_bias = self.compute_bias(block_len)
+ elif attention_mask is not None:
+ position_bias = jnp.zeros_like(attention_mask)
+ else:
+ position_bias = jnp.zeros((1, 1, self.n_heads, block_len, 3 * block_len), dtype=self.dtype)
+
+ return position_bias
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ key_value_states=None,
+ position_bias=None,
+ output_attentions=False,
+ deterministic=True,
+ ):
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ # Prepare components for transient-global attention
+ # Obtain block_ids and global_segment_ids
+ # global_seq_len := seq_len // self.global_block_size
+ # shapes: (batch_size, seq_len) & (batch_size, global_seq_len)
+ block_ids, global_segment_ids = _make_global_fixed_block_ids(
+ attention_mask if attention_mask is not None else jnp.ones((batch_size, seq_length)),
+ self.global_block_size,
+ )
+ # Create global inputs
+ _global_seq_len = global_segment_ids.shape[-1]
+ global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len)
+ global_inputs = self.global_input_layer_norm(global_inputs)
+
+ # q, k, v projections
+ query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head)
+ key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states)
+ value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states)
+
+ # reshape to (batch_size, seq_length, n_heads, head_dim)
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # Get global/side key/value_states
+ side_key_states = self.k(global_inputs)
+ side_value_states = self.v(global_inputs)
+
+ # reshape to (batch_size, global_seq_len, n_heads, head_dim)
+ side_key_states = self._split_heads(side_key_states)
+ side_value_states = self._split_heads(side_value_states)
+
+ # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, head_dim)
+ query_states = _split_into_blocks(query_states, self.block_len, axis=1)
+ key_states = _split_into_blocks(key_states, self.block_len, axis=1)
+ value_states = _split_into_blocks(value_states, self.block_len, axis=1)
+
+ # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
+ key_states = _concatenate_3_blocks(key_states, block_axis=1, sequence_axis=2)
+ value_states = _concatenate_3_blocks(value_states, block_axis=1, sequence_axis=2)
+
+ # Tile side inputs across local key/value blocks
+ # New shape: (batch_size, num_blocks, global_seq_len, n_heads, dim_per_head)
+ reps = [1] * (side_key_states.ndim + 1)
+ reps[1] = key_states.shape[1]
+ side_key_states = jnp.tile(side_key_states[:, None, ...], reps)
+ side_value_states = jnp.tile(side_value_states[:, None, ...], reps)
+
+ # Concatenate "local" and "side"/"global" key/value states to allow each token to attend global aggregated ones
+ # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head)
+ key_states = jnp.concatenate((key_states, side_key_states), axis=2)
+ value_states = jnp.concatenate((value_states, side_value_states), axis=2)
+
+ # counter-act scaling in dot_product_attention_weights function
+ query_states *= jnp.sqrt(query_states.shape[-1])
+
+ if attention_mask is not None:
+ local_attention_mask = _get_local_attention_mask(attention_mask, self.block_len)
+ local_attention_mask = jax.lax.select(
+ local_attention_mask > 0,
+ jnp.full(local_attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(local_attention_mask.shape, -1e10).astype(self.dtype),
+ )
+ else:
+ local_attention_mask = None
+
+ if position_bias is None:
+ # compute position bias (only for first layer)
+ position_bias = self._create_position_bias(self.block_len, attention_mask)
+ if local_attention_mask is not None:
+ position_bias = position_bias + local_attention_mask.swapaxes(1, 2)
+
+ # Calculate global/side bias - shape: (batch_size, num_heads, seq_len, global_seq_len)
+ if attention_mask is None:
+ attention_mask = jnp.ones((batch_size, seq_length))
+ side_position_bias = self.compute_side_bias(attention_mask, global_segment_ids)
+ side_position_bias = _split_into_blocks(side_position_bias, self.block_len, axis=-2)
+ side_position_bias = jnp.swapaxes(side_position_bias, 1, 2)
+ position_bias = jnp.concatenate((position_bias, side_position_bias), axis=-1)
+
+ # create dropout rng
+ dropout_rng = None
+ if not deterministic and self.dropout > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ # Softmax(QK^T)
+ attn_weights = dot_product_attention_weights(
+ query_states,
+ key_states,
+ bias=position_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.dropout,
+ broadcast_dropout=True,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ )
+
+ # multiply with value states
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+
+ # bring back to (batch_size, seq_length, d_model)
+ attn_output = self._merge_heads(attn_output)
+ attn_output = attn_output[:, :seq_length, :]
+
+ # apply output matrix
+ attn_output = self.o(attn_output)
+
+ outputs = (attn_output, position_bias)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+
+ return outputs
+
+
+class FlaxLongT5LayerLocalSelfAttention(nn.Module):
+ """Local self attention used in encoder"""
+
+ config: LongT5Config
+ has_relative_attention_bias: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.LocalSelfAttention = FlaxLongT5LocalAttention(
+ self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
+ )
+ self.layer_norm = FlaxLongT5LayerNorm(
+ self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ output_attentions=False,
+ deterministic=True,
+ **kwargs: Any, # to accept init_cache kwargs
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.LocalSelfAttention(
+ normed_hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ )
+ hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+class FlaxLongT5LayerTransientGlobalSelfAttention(nn.Module):
+ """Transient-Global self attention used in encoder"""
+
+ config: LongT5Config
+ has_relative_attention_bias: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.TransientGlobalSelfAttention = FlaxLongT5TransientGlobalAttention(
+ self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
+ )
+ self.layer_norm = FlaxLongT5LayerNorm(
+ self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ output_attentions=False,
+ deterministic=True,
+ **kwargs: Any, # to accept init_cache kwargs
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.TransientGlobalSelfAttention(
+ normed_hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ )
+ hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerSelfAttention with T5->LongT5
+class FlaxLongT5LayerSelfAttention(nn.Module):
+ config: LongT5Config
+ has_relative_attention_bias: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.SelfAttention = FlaxLongT5Attention(
+ self.config,
+ has_relative_attention_bias=self.has_relative_attention_bias,
+ causal=self.config.causal,
+ dtype=self.dtype,
+ )
+ self.layer_norm = FlaxLongT5LayerNorm(
+ self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ output_attentions=False,
+ deterministic=True,
+ init_cache=False,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.SelfAttention(
+ normed_hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ )
+ hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCrossAttention with T5->LongT5
+class FlaxLongT5LayerCrossAttention(nn.Module):
+ config: LongT5Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.EncDecAttention = FlaxLongT5Attention(
+ self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype
+ )
+ self.layer_norm = FlaxLongT5LayerNorm(
+ self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+
+ def __call__(
+ self,
+ hidden_states,
+ key_value_states,
+ attention_mask=None,
+ position_bias=None,
+ output_attentions=False,
+ deterministic=True,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.EncDecAttention(
+ normed_hidden_states,
+ attention_mask=attention_mask,
+ key_value_states=key_value_states,
+ position_bias=position_bias,
+ output_attentions=output_attentions,
+ )
+ hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+class FlaxLongT5Block(nn.Module):
+ config: LongT5Config
+ has_relative_attention_bias: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.causal = self.config.causal
+ if self.causal:
+ attention_layer = FlaxLongT5LayerSelfAttention
+ elif self.config.encoder_attention_type == "local":
+ attention_layer = FlaxLongT5LayerLocalSelfAttention
+ elif self.config.encoder_attention_type == "transient-global":
+ attention_layer = FlaxLongT5LayerTransientGlobalSelfAttention
+ else:
+ raise ValueError(
+ "For encoder attention mechanism, either `local` or `transient-global` attention type is expected, "
+ f"but got {self.config.encoder_attention_type}."
+ )
+ self.layer = (
+ attention_layer(
+ self.config,
+ has_relative_attention_bias=self.has_relative_attention_bias,
+ name=str(0),
+ dtype=self.dtype,
+ ),
+ )
+ feed_forward_index = 1
+ if self.causal:
+ self.layer += (FlaxLongT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),)
+ feed_forward_index += 1
+
+ self.layer += (FlaxLongT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),)
+
+ # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Block.__call__ with T5->LongT5
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ encoder_decoder_position_bias=None,
+ output_attentions=False,
+ return_dict=True,
+ deterministic=True,
+ init_cache=False,
+ ):
+ self_attention_outputs = self.layer[0](
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ )
+ hidden_states = self_attention_outputs[0]
+ attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
+
+ do_cross_attention = self.causal and encoder_hidden_states is not None
+ if do_cross_attention:
+ cross_attention_outputs = self.layer[1](
+ hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ position_bias=encoder_decoder_position_bias,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ )
+ hidden_states = cross_attention_outputs[0]
+
+ # Keep cross-attention outputs and relative position weights
+ attention_outputs = attention_outputs + cross_attention_outputs[1:]
+
+ # Apply Feed Forward layer
+ hidden_states = self.layer[-1](hidden_states, deterministic=deterministic)
+
+ outputs = (hidden_states,)
+
+ outputs = outputs + attention_outputs
+
+ # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+ return outputs
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCollection with T5->LongT5
+class FlaxLongT5LayerCollection(nn.Module):
+ config: LongT5Config
+ has_relative_attention_bias: bool
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.layer = FlaxLongT5Block(
+ self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
+ )
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ encoder_decoder_position_bias=None,
+ output_attentions=False,
+ return_dict=True,
+ deterministic=True,
+ init_cache=False,
+ ):
+ return self.layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ )
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5BlockCollection with T5->LongT5
+class FlaxLongT5BlockCollection(nn.Module):
+ config: LongT5Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.causal = self.config.causal
+ self.blocks = [
+ FlaxLongT5LayerCollection(self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i))
+ for i in range(self.config.num_layers)
+ ]
+
+ def __call__(
+ self,
+ hidden_states=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ ):
+ # Prepare head mask if needed
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and self.causal) else None
+ position_bias = None
+ encoder_decoder_position_bias = None
+
+ for i, layer_module in enumerate(self.blocks):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+            # We share the position biases between the layers - the first layer stores them
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+ position_bias = layer_outputs[1]
+
+ if self.causal and encoder_hidden_states is not None:
+ encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[2],)
+ if self.causal:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[4],)
+
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Stack with T5->LongT5
+class FlaxLongT5Stack(nn.Module):
+ config: LongT5Config
+ embed_tokens: nn.Embed
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.causal = self.config.causal
+
+ self.block = FlaxLongT5BlockCollection(self.config, dtype=self.dtype)
+ self.final_layer_norm = FlaxLongT5LayerNorm(
+ self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+
+ def __call__(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ ):
+ hidden_states = self.embed_tokens(input_ids)
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+
+ outputs = self.block(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ )
+
+ hidden_states = outputs[0]
+
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+
+ # Add last layer
+ all_hidden_states = None
+
+ if output_hidden_states:
+ all_hidden_states = outputs.hidden_states
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ if output_hidden_states:
+ return (
+ hidden_states,
+ all_hidden_states,
+ ) + outputs[2:]
+ return (hidden_states,) + outputs[1:]
+
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+LONGT5_ENCODE_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
+ you should be able to pad the inputs on both the right and the left.
+
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            To know more about how to prepare `input_ids` for pretraining, take a look at
+            [LONGT5 Training](./longt5#training).
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+LONGT5_DECODE_INPUTS_DOCSTRING = r"""
+ Args:
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ For training, `decoder_input_ids` should be provided.
+        encoder_outputs (`tuple(tuple(jnp.ndarray))`):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+
+            If you want to change padding behavior, you should modify it to your needs. See diagram 1 in [the
+ paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+LONGT5_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
+ you should be able to pad the inputs on both the right and the left.
+
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+            To know more about how to prepare `input_ids` for pretraining, take a look at
+            [LONGT5 Training](./longt5#training).
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+            To know more about how to prepare `decoder_input_ids` for pretraining, take a look at
+            [LONGT5 Training](./longt5#training).
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+        encoder_outputs (`tuple(tuple(jnp.ndarray))`, *optional*):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
+ the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ past_key_values (`tuple(tuple(jnp.ndarray))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
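+# Illustration (comment only, not part of the model code): as described in LONGT5_INPUTS_DOCSTRING, LongT5 uses
+# `pad_token_id` as the decoder start token, so for training `decoder_input_ids` are usually the labels shifted
+# one position to the right, e.g. (token ids are made up):
+#
+#     labels            = [259,  33, 123,   1]   # "... </s>"
+#     decoder_input_ids = [  0, 259,  33, 123]   # pad_token_id (0) prepended, last token dropped
+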
+
+class FlaxLongT5PreTrainedModel(FlaxPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = LongT5Config
+ base_model_prefix = "transformer"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: LongT5Config,
+ input_shape: Tuple[int] = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs
+ ):
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+
+ attention_mask = jnp.ones_like(input_ids)
+ decoder_input_ids = jnp.ones_like(input_ids)
+ decoder_attention_mask = jnp.ones_like(input_ids)
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ random_params = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ )["params"]
+
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING)
+ def __call__(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ decoder_input_ids: jnp.ndarray = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if decoder_input_ids is None:
+ raise ValueError(
+ "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed"
+ " here."
+ )
+
+ # prepare encoder inputs
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+
+ # prepare decoder inputs
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+
+ # Handle any PRNG if needed
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+ return self.module.apply(
+ {"params": params or self.params},
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ )
+
+ def init_cache(self, batch_size, max_length, encoder_outputs):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
+ `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
+ `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
+ is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
+ cross-attention of the decoder.
+ """
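+        # Usage sketch (illustrative, not part of the original code): the returned cache is typically fed back
+        # into `decode` as `past_key_values` for auto-regressive decoding, e.g.
+        #
+        #     encoder_outputs = model.encode(**inputs)
+        #     past_key_values = model.init_cache(batch_size, max_length, encoder_outputs)
+        #     outputs = model.decode(decoder_input_ids, encoder_outputs, past_key_values=past_key_values)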
+ # init input variables to retrieve cache
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
+ decoder_module = module._get_decoder_module()
+ return decoder_module(
+ decoder_input_ids,
+ decoder_attention_mask,
+ **kwargs,
+ )
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0),
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ init_cache=True,
+ method=_decoder_forward, # we only need to call the decoder to init the cache
+ )
+ return unfreeze(init_variables["cache"])
+
+ @add_start_docstrings(LONGT5_ENCODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=LongT5Config)
+ def encode(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import T5Tokenizer, FlaxLongT5ForConditionalGeneration
+
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-base")
+ >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
+
+ >>> text = "My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer(text, return_tensors="np")
+ >>> encoder_outputs = model.encode(**inputs)
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ def _encoder_forward(module, input_ids, attention_mask, **kwargs):
+ encode_module = module._get_encoder_module()
+ return encode_module(input_ids, attention_mask, **kwargs)
+
+ return self.module.apply(
+ {"params": params or self.params},
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ method=_encoder_forward,
+ )
+
+ @add_start_docstrings(LONGT5_DECODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=LongT5Config)
+ def decode(
+ self,
+ decoder_input_ids,
+ encoder_outputs,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ past_key_values: dict = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import T5Tokenizer, FlaxLongT5ForConditionalGeneration
+ >>> import jax.numpy as jnp
+
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-base")
+ >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
+
+ >>> text = "My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer(text, return_tensors="np")
+ >>> encoder_outputs = model.encode(**inputs)
+
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
+ >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
+ >>> logits = outputs.logits
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ encoder_hidden_states = encoder_outputs[0]
+ if encoder_attention_mask is None:
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ batch_size, sequence_length = decoder_input_ids.shape
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+        # If past_key_values are passed, then the cache is already initialized and a private flag, init_cache,
+        # has to be passed down to ensure the cache is used. It also has to be made sure that the cache is marked
+        # as mutable so that it can be changed by the FlaxLongT5Attention module.
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
+ decoder_module = module._get_decoder_module()
+ return decoder_module(
+ decoder_input_ids,
+ decoder_attention_mask,
+ **kwargs,
+ )
+
+ outputs = self.module.apply(
+ inputs,
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ mutable=mutable,
+ method=_decoder_forward,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past = outputs
+ outputs["past_key_values"] = unfreeze(past["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past = outputs
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+ return outputs
+
+
+LONGT5_START_DOCSTRING = r"""
+ The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long
+ Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo
+    Ni, Yun-Hsuan Sung and Yinfei Yang. It's an encoder-decoder transformer pre-trained in a text-to-text denoising
+    generative setting. The LongT5 model is an extension of the T5 model, and it enables using one of two different
+    efficient attention mechanisms: (1) Local attention, or (2) Transient-Global attention.
+
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+    heads, etc.).
+
+ This model is also a Flax Linen
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+    regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ config ([`LongT5Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+            specified, all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
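+# Illustration (comment only, not part of the original code): `dtype` only controls the computation dtype, so
+# bfloat16 inference on TPU could look like
+#
+#     model = FlaxLongT5Model.from_pretrained("google/long-t5-local-base", dtype=jnp.bfloat16)
+#     model.params = model.to_bf16(model.params)  # optionally also cast the parameters
+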
+
+@add_start_docstrings(
+    "The bare LONGT5 Model transformer outputting raw hidden-states without any specific head on top.",
+ LONGT5_START_DOCSTRING,
+)
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Module with T5->LongT5
+class FlaxLongT5Module(nn.Module):
+ config: LongT5Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def _get_encoder_module(self):
+ return self.encoder
+
+ def _get_decoder_module(self):
+ return self.decoder
+
+ def setup(self):
+ self.shared = nn.Embed(
+ self.config.vocab_size,
+ self.config.d_model,
+ embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+ )
+
+ encoder_config = copy.deepcopy(self.config)
+ encoder_config.causal = False
+ self.encoder = FlaxLongT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
+
+ decoder_config = copy.deepcopy(self.config)
+ decoder_config.causal = True
+ decoder_config.num_layers = self.config.num_decoder_layers
+ self.decoder = FlaxLongT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ decoder_input_ids=None,
+ decoder_attention_mask=None,
+ encoder_outputs=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ deterministic: bool = True,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Encode if needed (training, first prediction pass)
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return FlaxSeq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Model with T5->LongT5
+class FlaxLongT5Model(FlaxLongT5PreTrainedModel):
+ module_class = FlaxLongT5Module
+
+
+append_call_sample_docstring(
+ FlaxLongT5Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC
+)
+
+FLAX_LONGT5_MODEL_DOCSTRING = """
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import T5Tokenizer, FlaxLongT5Model
+
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-base")
+ >>> model = FlaxLongT5Model.from_pretrained("google/long-t5-local-base")
+
+ >>> input_ids = tokenizer(
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="np"
+ ... ).input_ids
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
+
+ >>> # forward pass
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```
+"""
+
+
+overwrite_call_docstring(FlaxLongT5Model, LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_MODEL_DOCSTRING)
+append_replace_return_docstrings(FlaxLongT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+
+
+@add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING)
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5ForConditionalGenerationModule with T5->LongT5
+class FlaxLongT5ForConditionalGenerationModule(nn.Module):
+ config: LongT5Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def _get_encoder_module(self):
+ return self.encoder
+
+ def _get_decoder_module(self):
+ return self.decoder
+
+ def setup(self):
+ self.model_dim = self.config.d_model
+
+ self.shared = nn.Embed(
+ self.config.vocab_size,
+ self.config.d_model,
+ embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
+ )
+
+ encoder_config = copy.deepcopy(self.config)
+ encoder_config.causal = False
+ encoder_config.use_cache = False
+ encoder_config.is_encoder_decoder = False
+ self.encoder = FlaxLongT5Stack(encoder_config, self.shared, dtype=self.dtype)
+
+ decoder_config = copy.deepcopy(self.config)
+ decoder_config.causal = True
+ decoder_config.is_encoder_decoder = False
+ decoder_config.num_layers = self.config.num_decoder_layers
+ self.decoder = FlaxLongT5Stack(decoder_config, self.shared, dtype=self.dtype)
+
+ self.lm_head = nn.Dense(
+ self.config.vocab_size,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_factor),
+ dtype=self.dtype,
+ )
+
+ def __call__(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ decoder_input_ids=None,
+ decoder_attention_mask=None,
+ encoder_outputs=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ deterministic: bool = True,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Encode
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=hidden_states,
+ encoder_attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ sequence_output = decoder_outputs[0]
+
+ if self.config.tie_word_embeddings:
+ # Rescale output before projecting on vocab
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ sequence_output = sequence_output * (self.model_dim**-0.5)
+
+ if self.config.tie_word_embeddings:
+ shared_embedding = self.shared.variables["params"]["embedding"]
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
+ else:
+ lm_logits = self.lm_head(sequence_output)
+
+ if not return_dict:
+ return (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+
+ return FlaxSeq2SeqLMOutput(
+ logits=lm_logits,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+class FlaxLongT5ForConditionalGeneration(FlaxLongT5PreTrainedModel):
+ module_class = FlaxLongT5ForConditionalGenerationModule
+
+ @add_start_docstrings(LONGT5_DECODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=LongT5Config)
+ def decode(
+ self,
+ decoder_input_ids,
+ encoder_outputs,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ past_key_values: dict = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import T5Tokenizer, FlaxLongT5ForConditionalGeneration
+ >>> import jax.numpy as jnp
+
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-base")
+ >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
+
+ >>> text = "summarize: My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer(text, return_tensors="np")
+ >>> encoder_outputs = model.encode(**inputs)
+
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
+ >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
+ >>> logits = outputs.logits
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ encoder_hidden_states = encoder_outputs[0]
+ if encoder_attention_mask is None:
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ batch_size, sequence_length = decoder_input_ids.shape
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+        # If past_key_values are passed, then the cache is already initialized and a private flag, init_cache,
+        # has to be passed down to ensure the cache is used. It also has to be made sure that the cache is marked
+        # as mutable so that it can be changed by the FlaxLongT5Attention module.
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
+ decoder_module = module._get_decoder_module()
+ decoder_outputs = decoder_module(
+ decoder_input_ids,
+ decoder_attention_mask,
+ **kwargs,
+ )
+
+ sequence_output = decoder_outputs[0]
+
+ if self.config.tie_word_embeddings:
+ # Rescale output before projecting on vocab
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ sequence_output = sequence_output * (self.config.d_model**-0.5)
+
+ if self.config.tie_word_embeddings:
+ shared_embedding = module.shared.variables["params"]["embedding"]
+ lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
+ else:
+ lm_logits = module.lm_head(sequence_output)
+
+ return lm_logits, decoder_outputs
+
+ outputs = self.module.apply(
+ inputs,
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ mutable=mutable,
+ method=_decoder_forward,
+ )
+
+ if past_key_values is None:
+ lm_logits, decoder_outputs = outputs
+ else:
+ (lm_logits, decoder_outputs), past = outputs
+
+ if return_dict:
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
+ logits=lm_logits,
+ hidden_states=decoder_outputs.hidden_states,
+ attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ )
+ else:
+ outputs = (lm_logits,) + decoder_outputs[1:]
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs["past_key_values"] = unfreeze(past["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+ return outputs
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ max_length,
+ attention_mask: Optional[jnp.DeviceArray] = None,
+ decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+ encoder_outputs=None,
+ **kwargs
+ ):
+ # initializing the cache
+ batch_size, seq_length = decoder_input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+        # But since the decoder uses a causal mask, those positions are masked anyway.
+        # Thus we can create a single static attention_mask here, which is more efficient for compilation.
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if decoder_attention_mask is not None:
+ extended_attention_mask = jax.lax.dynamic_update_slice(
+ extended_attention_mask, decoder_attention_mask, (0, 0)
+ )
+
+ return {
+ "past_key_values": past_key_values,
+ "encoder_outputs": encoder_outputs,
+ "encoder_attention_mask": attention_mask,
+ "decoder_attention_mask": extended_attention_mask,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ return model_kwargs
+
+
+FLAX_LONGT5_CONDITIONAL_GENERATION_DOCSTRING = """
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import T5Tokenizer, FlaxLongT5ForConditionalGeneration
+
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-base")
+ >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
+
+ >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np")
+
+ >>> # Generate Summary
+ >>> summary_ids = model.generate(inputs["input_ids"]).sequences
+ >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
+ ```
+"""
+
+
+overwrite_call_docstring(
+ FlaxLongT5ForConditionalGeneration, LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_CONDITIONAL_GENERATION_DOCSTRING
+)
+append_replace_return_docstrings(
+ FlaxLongT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
+)
diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py
new file mode 100644
index 000000000000..abd1cb778655
--- /dev/null
+++ b/src/transformers/models/longt5/modeling_longt5.py
@@ -0,0 +1,2193 @@
+# coding=utf-8
+# Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LongT5 model."""
+
+
+import copy
+import math
+import warnings
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from torch.utils.checkpoint import checkpoint
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ DUMMY_INPUTS,
+ DUMMY_MASK,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_torch_fx_proxy,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_longt5 import LongT5Config
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LongT5Config"
+_TOKENIZER_FOR_DOC = "T5Tokenizer"
+_CHECKPOINT_FOR_DOC = "google/long-t5-local-base"
+
+# TODO: Update before the merge
+LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "google/long-t5-local-base",
+ "google/long-t5-local-large",
+ "google/long-t5-tglobal-base",
+ "google/long-t5-tglobal-large",
+]
+
+
+def _pad_to_multiple(x: torch.Tensor, block_len: int, dim: int, pad_value: int = 0) -> torch.Tensor:
+ """Pad a tensor so that a sequence length will be a multiple of `block_len`"""
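+    # For illustration (comment only, not part of the original code): with `block_len=3`, a tensor of shape
+    # (1, 7, 4) is padded along `dim=1` with two rows of `pad_value` so that the resulting shape (1, 9, 4)
+    # can later be reshaped into blocks of length 3.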
+ pad_len = -x.shape[dim] % block_len
+ # Handle cases when an empty input sequence is given
+ if not all(x.shape):
+ new_shape = list(x.shape)
+ new_shape[dim] += pad_len
+        return torch.zeros(new_shape, dtype=x.dtype, device=x.device)
+
+ pad = [(0, 0)] * x.ndim
+ pad[dim] = (0, pad_len)
+ pad = sum(pad[::-1], ())
+ x = nn.functional.pad(x, pad=pad, mode="constant", value=pad_value)
+ return x
+
+
+def _split_into_blocks(x: torch.Tensor, block_len: int, dim: int) -> torch.Tensor:
+ """Split an input tensor into blocks of a given `block_len` along the given `dim`. If the dimension length
+ is not a multiple of `block_len`, it will be padded first with selected `pad_value`.
+ """
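+    # For illustration (comment only): a tensor of shape (1, 7, 4) with `block_len=3` and `dim=1` is first
+    # padded to (1, 9, 4) and then reshaped to (1, 3, 3, 4), i.e. (batch_size, num_blocks, block_len, hidden_dim).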
+ # pad tensor to multiple of block_len
+ if x.shape[dim] % block_len != 0:
+ x = _pad_to_multiple(x, block_len, dim, pad_value=0)
+ num_blocks = x.shape[dim] // block_len
+ output_shape = x.shape[:dim] + (num_blocks, block_len) + x.shape[(dim + 1) :]
+ # If 0 is in output_shape, we cannot apply reshape because of incompatibility with ONNX conversion
+ if 0 in output_shape:
+ return torch.empty(output_shape, dtype=x.dtype, device=x.device)
+ return x.reshape(output_shape)
+
+
+def _concatenate_3_blocks(x: torch.Tensor, block_dim: int, sequence_dim: int, pad_value: int = 0) -> torch.Tensor:
+    """Concatenate three consecutive blocks for each input block for local attention.
+
+ For more information, see: https://arxiv.org/pdf/2112.07916.pdf.
+ """
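+    # For illustration (comment only): given blocked states of shape (batch_size, num_blocks, block_len, dim),
+    # each block is concatenated with its left and right neighbors (zero-padded at the boundaries), giving
+    # (batch_size, num_blocks, 3 * block_len, dim) so that every block can attend to its local neighborhood.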
+ num_blocks = x.shape[block_dim]
+
+ pad = [(0, 0)] * x.ndim
+ pad[block_dim] = (1, 1)
+ pad = sum(pad[::-1], ())
+ # [batch_size, num_blocks, block_len] -> [batch_size, num_blocks + 2, block_len]
+ x = nn.functional.pad(x, pad=pad, mode="constant", value=pad_value)
+
+ blocks_list: List[torch.Tensor] = []
+ for i in range(3):
+ # We use indexing approach here:
+ # https://numpy.org/doc/stable/user/basics.indexing.html#dealing-with-variable-numbers-of-indices-within-programs
+ indices = [slice(0, None)] * x.ndim
+ indices[block_dim] = slice(i, i + num_blocks)
+ indices = tuple(indices)
+ blocks_list.append(x[indices])
+ # [batch_size, num_blocks, 3 * block_len, ...]
+ return torch.cat(blocks_list, dim=sequence_dim)
+
+
+def _make_3block_relative_position_ids(block_len: int) -> torch.Tensor:
+ """Makes 3-blocked relative position ids for local attention."""
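+    # For illustration (comment only): with `block_len=2`, position_ids = [0..5], the center positions are
+    # [2, 3], and the result is
+    #     [[-2, -1, 0, 1, 2, 3],
+    #      [-3, -2, -1, 0, 1, 2]]   # shape (block_len, 3 * block_len)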
+ position_ids = torch.arange(3 * block_len, dtype=torch.int32)
+ center_position_ids = position_ids[block_len:-block_len]
+ # [block_len, 3 * block_len]
+ relative_position_ids = position_ids.unsqueeze(0) - center_position_ids.unsqueeze(1)
+ return relative_position_ids
+
+
+def _mask_local_attention_mask(local_attention_mask: torch.Tensor, block_len: int) -> torch.Tensor:
+    """Mask local attention mask to enforce that tokens are not allowed to attend to tokens farther than `local_radius`."""
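+    # For illustration (comment only): continuing the `block_len=2` example above, positions with
+    # |relative position| >= block_len are masked out, i.e. the locality mask is
+    #     [[False, True, True, True, False, False],
+    #      [False, False, True, True, True, False]]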
+ relative_position_ids = _make_3block_relative_position_ids(block_len)
+ locality_mask = torch.abs(relative_position_ids) < block_len
+ locality_mask = locality_mask[None, None, :, :]
+ locality_mask = locality_mask.to(local_attention_mask.device)
+ return torch.logical_and(local_attention_mask, locality_mask)
+
+
+def _get_local_attention_mask(attention_mask: torch.Tensor, block_len: int, device: torch.device) -> torch.Tensor:
+ """Prepare attention mask to be applied for a local attention."""
+ # [batch_size, num_blocks, block_len]
+ _blocked_attention_mask = _split_into_blocks(attention_mask, block_len, dim=1)
+ # [batch_size, num_block, 3 * block_len]
+ _3blocked_attention_mask = _concatenate_3_blocks(_blocked_attention_mask, block_dim=1, sequence_dim=2)
+
+ _blocked_attention_mask = _blocked_attention_mask.unsqueeze(-1)
+ _3blocked_attention_mask = _3blocked_attention_mask.unsqueeze(-2)
+ # [batch_size, num_block, block_len, 3 * block_len]
+ local_attention_mask = torch.logical_and(_blocked_attention_mask, _3blocked_attention_mask)
+ local_attention_mask = _mask_local_attention_mask(local_attention_mask, block_len)
+ # [batch_size, 1, num_block, block_len, 3 * block_len]
+ return local_attention_mask.unsqueeze(1).to(device)
+
+
+def _make_global_fixed_block_ids(
+ attention_mask: torch.Tensor, global_block_size: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Obtain the "fixed block" global id corresponding to each input token.
+
+    This implementation is a simplified version of the original Flaxformer implementation adapted from:
+    https://github.com/google/flaxformer/blob/main/flaxformer/architectures/longt5/long_attention.py.
+
+    In our scenario, as we use this strategy only for a decoder, orphan tokens, i.e. those tokens which do not make up
+    a whole fixed block, are assigned to the preceding block.
+
+ Padding tokens from the original sequence are represented by -1.
+ """
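+    # Worked example (comment only, values computed by hand): for attention_mask [1, 1, 1, 1, 1, 0] and
+    # global_block_size=2, the returned block ids are [0, 0, 1, 1, 1, -1] (the orphan token at position 4 is
+    # folded into the preceding block 1 and the padding token gets -1), and the global_segment_ids are
+    # [1, 1, 0] (only two valid global blocks).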
+ batch_size, seq_len = attention_mask.shape[:2]
+
+ def handle_orphan_tokens(block_ids: torch.Tensor) -> torch.Tensor:
+ block_ends = (torch.arange(seq_len) % global_block_size) == global_block_size - 1
+ block_ends = block_ends.to(block_ids.device)
+ true_block_ends = torch.logical_and(block_ends, block_ids >= 0)
+ full_blocks = true_block_ends.sum(-1).unsqueeze(-1).type(block_ids.dtype) - 1
+ block_ids = torch.where(block_ids < full_blocks, block_ids, full_blocks)
+ return block_ids
+
+ fixed_block_mask = torch.ones_like(attention_mask, device=attention_mask.device) / global_block_size
+ fixed_block_mask = torch.cumsum(fixed_block_mask, axis=1) - fixed_block_mask
+ mask = torch.where(attention_mask != 0.0, 1.0, -1000.0).type(attention_mask.dtype)
+ global_block_ids = torch.floor(mask + fixed_block_mask - 1.0).type(attention_mask.dtype)
+ _global_block_ids_lower_bound = torch.tensor(-1, dtype=global_block_ids.dtype, device=global_block_ids.device)
+ global_block_ids = torch.where(
+ global_block_ids > _global_block_ids_lower_bound, global_block_ids, _global_block_ids_lower_bound
+ )
+ # set padding tokens to -1
+ global_block_ids = (global_block_ids * attention_mask) + (attention_mask - 1)
+ # [batch_size, seq_len]
+ global_block_ids = handle_orphan_tokens(global_block_ids)
+ num_globals = seq_len // global_block_size
+ # [batch_size, seq_len // global_block_size]
+ if num_globals > 0:
+ _sequence_block_ids_max = torch.max(global_block_ids, dim=-1).values.repeat(num_globals, 1).transpose(0, 1)
+ else:
+ _sequence_block_ids_max = torch.zeros(
+ batch_size, 0, dtype=global_block_ids.dtype, device=global_block_ids.device
+ )
+ global_segment_ids = torch.cumsum(torch.ones(batch_size, num_globals), dim=-1) - 1
+ global_segment_ids = global_segment_ids.to(attention_mask.device)
+ global_segment_ids = torch.where(global_segment_ids <= _sequence_block_ids_max, 1, 0)
+ return global_block_ids.type(torch.int), global_segment_ids.type(torch.int)
+
+
+def _make_side_relative_position_ids(attention_mask: torch.Tensor, global_block_size: int) -> torch.Tensor:
+ """Create the relative position tensor for local -> global attention."""
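+    # For illustration (comment only): a token assigned to fixed block `b` gets relative positions `g - b`
+    # towards the global blocks g = 0, 1, ..., e.g. a token in block 1 with three global blocks gets the
+    # side relative positions [-1, 0, 1].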
+ block_ids, global_segment_ids = _make_global_fixed_block_ids(attention_mask, global_block_size)
+ global_seq_len = global_segment_ids.shape[-1]
+ global_positions = torch.arange(global_seq_len, device=block_ids.device)
+ side_relative_position = global_positions - block_ids[..., None]
+ return side_relative_position.type(torch.int64)
+
+
+def _create_global_aggregates(
+ hidden_states: torch.Tensor, block_ids: torch.Tensor, global_seq_len: int
+) -> torch.Tensor:
+ """Compute individual block aggregates by summing over individual blocks."""
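+    # For illustration (comment only): each global block aggregate is the sum of the hidden states of all tokens
+    # assigned to that block; padding tokens (block id -1) are routed to an extra one-hot slot that is sliced
+    # away below.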
+    # (batch..., seq_len, global_seq_len)
+ block_ids = block_ids.where(
+ block_ids >= 0, torch.tensor(global_seq_len, dtype=block_ids.dtype, device=block_ids.device)
+ )
+ one_hot_block_ids = nn.functional.one_hot(block_ids.type(torch.int64), global_seq_len + 1)[:, :, :-1]
+ return torch.einsum("...nd,...ng->...gd", hidden_states, one_hot_block_ids.type(hidden_states.dtype))
+
+
+# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->LongT5
+class LongT5LayerNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Construct a layernorm module in the LongT5 style. No bias and no subtraction of mean.
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+
+        # LongT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
+        # Square Layer Normalization (https://arxiv.org/abs/1910.07467); thus the variance is calculated
+        # without the mean and there is no bias. Additionally, we want to make sure that the accumulation for
+        # half-precision inputs is done in fp32.
+
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+
+ return self.weight * hidden_states
+
+
+try:
+ from apex.normalization import FusedRMSNorm
+
+ LongT5LayerNorm = FusedRMSNorm # noqa
+
+ logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of LongT5LayerNorm")
+except ImportError:
+ # using the normal LongT5LayerNorm
+ pass
+except Exception:
+ logger.warning("discovered apex but it failed to load, falling back to LongT5LayerNorm")
+ pass
+
+ALL_LAYERNORM_LAYERS.append(LongT5LayerNorm)
+
+
+# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->LongT5
+class LongT5DenseActDense(nn.Module):
+ def __init__(self, config: LongT5Config):
+ super().__init__()
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
+ self.dropout = nn.Dropout(config.dropout_rate)
+ self.act = ACT2FN[config.dense_act_fn]
+
+ def forward(self, hidden_states):
+ hidden_states = self.wi(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.wo(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->LongT5
+class LongT5DenseGatedActDense(nn.Module):
+ def __init__(self, config: LongT5Config):
+ super().__init__()
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
+ self.dropout = nn.Dropout(config.dropout_rate)
+ self.act = ACT2FN[config.dense_act_fn]
+
+ def forward(self, hidden_states):
+ hidden_gelu = self.act(self.wi_0(hidden_states))
+ hidden_linear = self.wi_1(hidden_states)
+ hidden_states = hidden_gelu * hidden_linear
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.wo(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->LongT5
+class LongT5LayerFF(nn.Module):
+ def __init__(self, config: LongT5Config):
+ super().__init__()
+ if config.is_gated_act:
+ self.DenseReluDense = LongT5DenseGatedActDense(config)
+ else:
+ self.DenseReluDense = LongT5DenseActDense(config)
+
+ self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ def forward(self, hidden_states):
+ forwarded_states = self.layer_norm(hidden_states)
+ forwarded_states = self.DenseReluDense(forwarded_states)
+ hidden_states = hidden_states + self.dropout(forwarded_states)
+ return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->LongT5
+class LongT5Attention(nn.Module):
+ def __init__(self, config: LongT5Config, has_relative_attention_bias=False):
+ super().__init__()
+ self.is_decoder = config.is_decoder
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
+ self.relative_attention_max_distance = config.relative_attention_max_distance
+ self.d_model = config.d_model
+ self.key_value_proj_dim = config.d_kv
+ self.n_heads = config.num_heads
+ self.dropout = config.dropout_rate
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+ # Mesh TensorFlow initialization to avoid scaling before softmax
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+ self.pruned_heads = set()
+ self.gradient_checkpointing = False
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
+ )
+ # Prune linear layers
+ self.q = prune_linear_layer(self.q, index)
+ self.k = prune_linear_layer(self.k, index)
+ self.v = prune_linear_layer(self.v, index)
+ self.o = prune_linear_layer(self.o, index, dim=1)
+ # Update hyper params
+ self.n_heads = self.n_heads - len(heads)
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ @staticmethod
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+        positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same bucket.
+        This should allow for more graceful generalization to longer sequences than the model has been trained on.
+
+ Args:
+ relative_position: an int32 Tensor
+ bidirectional: a boolean - whether the attention is bidirectional
+ num_buckets: an integer
+ max_distance: an integer
+
+ Returns:
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+ """
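+        # Worked example (comment only, assuming the default num_buckets=32, max_distance=128 and
+        # bidirectional=True): half of the buckets encode the sign, so a relative position of -3 falls into
+        # bucket 3 while +3 falls into bucket 16 + 3 = 19; larger distances such as -50 land in one of the
+        # logarithmically spaced buckets (here bucket 13).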
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+ relative_position = torch.abs(relative_position)
+ else:
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_position_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.long)
+ relative_position_if_large = torch.min(
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+ )
+
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+ return relative_buckets
+
+ def compute_bias(self, query_length, key_length, device=None):
+ """Compute binned relative position bias"""
+ if device is None:
+ device = self.relative_attention_bias.weight.device
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+ relative_position = memory_position - context_position # shape (query_length, key_length)
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position, # shape (query_length, key_length)
+ bidirectional=(not self.is_decoder),
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
+ return values
+
+ def forward(
+ self,
+ hidden_states,
+ mask=None,
+ key_value_states=None,
+ position_bias=None,
+ past_key_value=None,
+ layer_head_mask=None,
+ query_length=None,
+ use_cache=False,
+ output_attentions=False,
+ ):
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ # Input is (batch_size, seq_length, dim)
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ real_seq_length = seq_length
+
+ if past_key_value is not None:
+ assert (
+ len(past_key_value) == 2
+ ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+ real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
+
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
+
+ def shape(states):
+ """projection"""
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+ def unshape(states):
+ """reshape"""
+ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
+
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
+ """projects hidden states correctly to key/query states"""
+ if key_value_states is None:
+ # self-attn
+ # (batch_size, n_heads, seq_length, dim_per_head)
+ hidden_states = shape(proj_layer(hidden_states))
+ elif past_key_value is None:
+ # cross-attn
+ # (batch_size, n_heads, seq_length, dim_per_head)
+ hidden_states = shape(proj_layer(key_value_states))
+
+ if past_key_value is not None:
+ if key_value_states is None:
+ # self-attn
+ # (batch_size, n_heads, key_length, dim_per_head)
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+ else:
+ # cross-attn
+ hidden_states = past_key_value
+ return hidden_states
+
+ # get query states
+ query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
+
+ # get key/value states
+ key_states = project(
+ hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
+ )
+ value_states = project(
+ hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
+ )
+
+ # compute scores
+ scores = torch.matmul(
+ query_states, key_states.transpose(3, 2)
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+
+ if position_bias is None:
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros(
+ (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
+ )
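+ # Note: the all-zero bias is marked as requiring grad below so that, under gradient checkpointing,
+ # the checkpointed block still receives an input with requires_grad=True.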
+ if self.gradient_checkpointing and self.training:
+ position_bias.requires_grad = True
+ else:
+ position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
+
+ # if key and values are already calculated
+ # we want only the last query position bias
+ if past_key_value is not None:
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
+
+ if mask is not None:
+ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
+
+ scores += position_bias
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
+ scores
+ ) # (batch_size, n_heads, seq_length, key_length)
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ ) # (batch_size, n_heads, seq_length, key_length)
+
+ # Mask heads if we want to
+ if layer_head_mask is not None:
+ attn_weights = attn_weights * layer_head_mask
+
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
+ attn_output = self.o(attn_output)
+
+ present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+ return outputs
+
+
+class LongT5LocalAttention(nn.Module):
+ def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = False) -> None:
+ super().__init__()
+ self.is_decoder = config.is_decoder
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
+ self.relative_attention_max_distance = config.relative_attention_max_distance
+ self.d_model = config.d_model
+ self.key_value_proj_dim = config.d_kv
+ self.n_heads = config.num_heads
+ self.local_radius = config.local_radius
+ self.block_len = self.local_radius + 1
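+ # e.g. with the default `local_radius` of 127 this corresponds to a block length of 128 tokens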
+ self.dropout = config.dropout_rate
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+ # Mesh TensorFlow initialization to avoid scaling before softmax
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+ self.pruned_heads = set()
+ self.gradient_checkpointing = False
+
+ # Copied from transformers.models.t5.modeling_t5.T5Attention.prune_heads
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
+ )
+ # Prune linear layers
+ self.q = prune_linear_layer(self.q, index)
+ self.k = prune_linear_layer(self.k, index)
+ self.v = prune_linear_layer(self.v, index)
+ self.o = prune_linear_layer(self.o, index, dim=1)
+ # Update hyper params
+ self.n_heads = self.n_heads - len(heads)
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ @staticmethod
+ # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
+
+ Args:
+ relative_position: an int32 Tensor
+ bidirectional: a boolean - whether the attention is bidirectional
+ num_buckets: an integer
+ max_distance: an integer
+
+ Returns:
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+ """
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+ relative_position = torch.abs(relative_position)
+ else:
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_position_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.long)
+ relative_position_if_large = torch.min(
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+ )
+
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+ return relative_buckets
+
+ def compute_bias(self, block_length: int):
+ """Compute binned relative position bias"""
+ memory_position = torch.arange(
+ 3 * block_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
+ )
+ context_position = memory_position[block_length:-block_length]
+
+ # (block_length, 3 * block_length)
+ relative_position = memory_position[None, :] - context_position[:, None]
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position, # (block_length, 3 * block_length)
+ bidirectional=(not self.is_decoder),
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+ # (block_length, 3 * block_length, num_heads)
+ values = self.relative_attention_bias(relative_position_bucket)
+ # (1, 1, num_heads, block_length, 3 * block_length)
+ values = values.permute([2, 0, 1]).unsqueeze(0).unsqueeze(0)
+ return values
+
+ def forward(
+ self,
+ hidden_states,
+ mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ output_attentions=False,
+ ):
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ def shape(states):
+ """projection"""
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
+
+ def unshape(states):
+ """reshape"""
+ return states.contiguous().view(batch_size, -1, self.inner_dim)
+
+ # get query/key/value states -> (batch_size, seq_length, n_heads, dim_per_head)
+ query_states = shape(self.q(hidden_states))
+ key_states = shape(self.k(hidden_states))
+ value_states = shape(self.v(hidden_states))
+
+ # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
+ query_states = _split_into_blocks(query_states, self.block_len, dim=1)
+ key_states = _split_into_blocks(key_states, self.block_len, dim=1)
+ value_states = _split_into_blocks(value_states, self.block_len, dim=1)
+
+ # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
+ key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
+ value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)
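+ # Each query block thus attends to the keys/values of its own block plus the neighbouring block on
+ # either side - a window of 3 * block_len key positions - which guarantees at least `local_radius`
+ # visible tokens on each side of every query (the helper pads an extra block at each sequence edge).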
+
+ # Compute scores
+ scores = torch.einsum(
+ "...qhd,...khd->...hqk", query_states, key_states
+ ) # (batch_size, num_block, n_heads, block_len, 3 * block_len)
+
+ if position_bias is None:
+ # position_bias shape: (1, 1, n_heads, block_len, 3 * block_len)
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros(
+ (1, 1, self.n_heads, self.block_len, 3 * self.block_len), device=scores.device, dtype=scores.dtype
+ )
+ if self.gradient_checkpointing and self.training:
+ position_bias.requires_grad = True
+ else:
+ position_bias = self.compute_bias(self.block_len)
+
+ if mask is not None:
+ # Replace masked positions with -1e10 (according to the original implementation)
+ mask = torch.where(mask > 0, 0.0, -1e10)
+ # We need to adjust the position bias shape so it can be summed with the mask
+ position_bias = position_bias + mask.transpose(1, 2)
+
+ scores += position_bias
+ # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
+ # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ # Mask heads if we want to
+ if layer_head_mask is not None:
+ attn_weights = attn_weights * layer_head_mask
+ attn_weights = attn_weights.type(value_states.dtype)
+ attn_output = unshape(torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states))
+ attn_output = attn_output[:, :seq_length, :]
+ attn_output = self.o(attn_output)
+
+ present_key_value_state = None
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+ return outputs
+
+
+class LongT5TransientGlobalAttention(nn.Module):
+ def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = False) -> None:
+ super().__init__()
+ self.is_decoder = config.is_decoder
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
+ self.relative_attention_max_distance = config.relative_attention_max_distance
+ self.d_model = config.d_model
+ self.key_value_proj_dim = config.d_kv
+ self.n_heads = config.num_heads
+ self.local_radius = config.local_radius
+ self.block_len = self.local_radius + 1
+ self.global_block_size = config.global_block_size
+ self.dropout = config.dropout_rate
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+ # Mesh TensorFlow initialization to avoid scaling before softmax
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+ self.pruned_heads = set()
+ self.gradient_checkpointing = False
+
+ # Relative attention bias & layer norm for global attention
+ if self.has_relative_attention_bias:
+ self.global_relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+ self.global_input_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+
+ # Copied from transformers.models.t5.modeling_t5.T5Attention.prune_heads
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
+ )
+ # Prune linear layers
+ self.q = prune_linear_layer(self.q, index)
+ self.k = prune_linear_layer(self.k, index)
+ self.v = prune_linear_layer(self.v, index)
+ self.o = prune_linear_layer(self.o, index, dim=1)
+ # Update hyper params
+ self.n_heads = self.n_heads - len(heads)
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ @staticmethod
+ # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
+
+ Args:
+ relative_position: an int32 Tensor
+ bidirectional: a boolean - whether the attention is bidirectional
+ num_buckets: an integer
+ max_distance: an integer
+
+ Returns:
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+ """
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+ relative_position = torch.abs(relative_position)
+ else:
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_position_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.long)
+ relative_position_if_large = torch.min(
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+ )
+
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+ return relative_buckets
+
+ def compute_bias(self, block_length: int):
+ """Compute binned relative position bias"""
+ memory_position = torch.arange(
+ 3 * block_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
+ )
+ context_position = memory_position[block_length:-block_length]
+
+ # (block_length, 3 * block_length)
+ relative_position = memory_position[None, :] - context_position[:, None]
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position, # (block_length, 3 * block_length)
+ bidirectional=(not self.is_decoder),
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+ # (block_length, 3 * block_length, num_heads)
+ values = self.relative_attention_bias(relative_position_bucket)
+ # (1, 1, num_heads, block_length, 3 * block_length)
+ values = values.permute([2, 0, 1]).unsqueeze(0).unsqueeze(0)
+ return values
+
+ def compute_side_bias(self, mask: torch.Tensor, global_segment_ids: torch.Tensor) -> torch.Tensor:
+ # (batch_size, 1, seq_len, global_seq_len)
+ side_attention_mask = torch.eq(mask[..., None], global_segment_ids[:, None, :])[:, None, ...]
+ attention_side_bias = torch.where(side_attention_mask > 0, 0.0, -1e10)
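+ # Roughly: real tokens (mask == 1) may attend to global tokens that aggregate real content
+ # (segment id == 1); all other token/global pairs receive a large negative bias (-1e10).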
+ # (batch_size, seq_len, global_seq_len)
+ side_relative_position = _make_side_relative_position_ids(mask, self.global_block_size)
+ side_relative_position_bucket = self._relative_position_bucket(
+ side_relative_position,
+ bidirectional=(not self.is_decoder),
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+ # (batch_size, seq_len, global_seq_len, num_heads)
+ side_bias = self.global_relative_attention_bias(side_relative_position_bucket)
+
+ # (batch_size, num_heads, seq_len, global_seq_len)
+ side_bias = side_bias.permute([0, 3, 1, 2])
+ # (batch_size, num_heads, seq_len, global_seq_len)
+ attention_side_bias = attention_side_bias + side_bias
+ return attention_side_bias
+
+ def forward(
+ self,
+ hidden_states,
+ mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ output_attentions=False,
+ ):
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ def shape(states):
+ """projection"""
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
+
+ def unshape(states):
+ """reshape"""
+ return states.contiguous().view(batch_size, -1, self.inner_dim)
+
+ # Prepare components for transient-global attention
+ # Obtain block_ids and global_segment_ids
+ # global_seq_len := seq_len // self.global_block_size
+ # shapes: (batch_size, seq_len) & (batch_size, global_seq_len)
+ block_ids, global_segment_ids = _make_global_fixed_block_ids(
+ mask if mask is not None else torch.ones(hidden_states.shape[:-1]),
+ self.global_block_size,
+ )
+ # Create global inputs
+ _global_seq_len = global_segment_ids.shape[-1]
+ global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len)
+ global_inputs = self.global_input_layer_norm(global_inputs)
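+ # Roughly: every consecutive block of `global_block_size` tokens is aggregated into one "global"
+ # token (so, for an unpadded input, _global_seq_len == seq_length // global_block_size). These
+ # global tokens form the transient memory that every local block can additionally attend to.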
+
+ # get query states -> (batch_size, seq_length, n_heads, dim_per_head)
+ query_states = shape(self.q(hidden_states))
+ key_states = shape(self.k(hidden_states))
+ value_states = shape(self.v(hidden_states))
+ # Get global/side key/value states shape: (batch_size, global_seq_len, n_heads, dim_per_head)
+ side_key_states = shape(self.k(global_inputs))
+ side_value_states = shape(self.v(global_inputs))
+
+ # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
+ query_states = _split_into_blocks(query_states, self.block_len, dim=1)
+ key_states = _split_into_blocks(key_states, self.block_len, dim=1)
+ value_states = _split_into_blocks(value_states, self.block_len, dim=1)
+
+ # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
+ key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
+ value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)
+
+ # Tile side inputs across local key/value blocks
+ # New shape: (batch_size, num_blocks, global_seq_len, n_heads, dim_per_head)
+ reps = [1] * (side_key_states.ndim + 1)
+ reps[1] = key_states.shape[1]
+ side_key_states = side_key_states.unsqueeze(1).repeat(reps)
+ side_value_states = side_value_states.unsqueeze(1).repeat(reps)
+
+ # Concatenate "local" and "side"/"global" key/value states to allow each token to attend global aggregated ones
+ # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head)
+ key_states = torch.cat([key_states, side_key_states], dim=2)
+ value_states = torch.cat([value_states, side_value_states], dim=2)
+
+ # Compute scores -> (batch_size, num_block, n_heads, block_len, 3 * block_len + global_seq_len)
+ scores = torch.einsum("...qhd,...khd->...hqk", query_states, key_states)
+
+ if mask is not None:
+ # We need to adjust the position bias shape so it can be summed with the mask
+ local_attention_mask = _get_local_attention_mask(mask, self.block_len, hidden_states.device)
+ # Replace masked positions with -1e10 (according to the original implementation)
+ local_attention_mask = torch.where(local_attention_mask > 0, 0.0, -1e10)
+ else:
+ local_attention_mask = None
+
+ if position_bias is None:
+ # position_bias shape: (1, 1, n_heads, block_len, 3 * block_len)
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros(
+ (1, 1, self.n_heads, self.block_len, 3 * self.block_len),
+ device=scores.device,
+ dtype=scores.dtype,
+ )
+ if self.gradient_checkpointing and self.training:
+ position_bias.requires_grad = True
+ else:
+ position_bias = self.compute_bias(self.block_len)
+
+ if local_attention_mask is not None:
+ # (batch_size, 1, n_heads, block_len, 3 * block_len)
+ position_bias = position_bias + local_attention_mask.transpose(1, 2)
+ position_bias = position_bias.type(scores.dtype)
+
+ # Calculate global/side bias - shape: (batch_size, num_heads, seq_len, global_seq_len)
+ if mask is None:
+ mask = torch.ones(batch_size, seq_length)
+ # (batch_size, num_heads, seq_len, global_seq_len)
+ side_position_bias = self.compute_side_bias(mask, global_segment_ids)
+ # (batch_size, num_blocks, num_heads, block_len, global_seq_len)
+ side_position_bias = _split_into_blocks(side_position_bias, self.block_len, dim=-2).transpose(1, 2)
+ side_position_bias = side_position_bias.type(scores.dtype).to(scores.device)
+ # (batch_size, num_blocks, num_heads, block_len, 3 * block_len + global_seq_len)
+ position_bias = torch.cat([position_bias, side_position_bias], dim=-1)
+
+ scores += position_bias
+ # (batch_size, num_blocks, n_heads, block_len, 3 * block_len + global_seq_len)
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ # Mask heads if we want to
+ if layer_head_mask is not None:
+ attn_weights = attn_weights * layer_head_mask
+ attn_weights = attn_weights.type(value_states.dtype)
+ attn_output = unshape(torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states))
+ attn_output = attn_output[:, :seq_length, :]
+ attn_output = self.o(attn_output)
+
+ present_key_value_state = None
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+ return outputs
+
+
+# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5
+class LongT5LayerSelfAttention(nn.Module):
+ def __init__(self, config, has_relative_attention_bias=False):
+ super().__init__()
+ self.SelfAttention = LongT5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
+ self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ past_key_value=None,
+ use_cache=False,
+ output_attentions=False,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.SelfAttention(
+ normed_hidden_states,
+ mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = hidden_states + self.dropout(attention_output[0])
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+class LongT5LayerLocalSelfAttention(nn.Module):
+ """Local self attention used in encoder"""
+
+ def __init__(self, config, has_relative_attention_bias=False):
+ super().__init__()
+ self.LocalSelfAttention = LongT5LocalAttention(config, has_relative_attention_bias=has_relative_attention_bias)
+ self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ output_attentions=False,
+ **kwargs: Any, # to accept past_key_value and use_cache kwargs
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.LocalSelfAttention(
+ normed_hidden_states,
+ mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = hidden_states + self.dropout(attention_output[0])
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+class LongT5LayerTransientGlobalSelfAttention(nn.Module):
+ """Transient-Global self attention used in encoder"""
+
+ def __init__(self, config, has_relative_attention_bias=False):
+ super().__init__()
+ self.TransientGlobalSelfAttention = LongT5TransientGlobalAttention(
+ config, has_relative_attention_bias=has_relative_attention_bias
+ )
+ self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ output_attentions=False,
+ **kwargs: Any, # to accept past_key_value and use_cache kwargs
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.TransientGlobalSelfAttention(
+ normed_hidden_states,
+ mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = hidden_states + self.dropout(attention_output[0])
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5
+class LongT5LayerCrossAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False)
+ self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ def forward(
+ self,
+ hidden_states,
+ key_value_states,
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ past_key_value=None,
+ use_cache=False,
+ query_length=None,
+ output_attentions=False,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.EncDecAttention(
+ normed_hidden_states,
+ mask=attention_mask,
+ key_value_states=key_value_states,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ query_length=query_length,
+ output_attentions=output_attentions,
+ )
+ layer_output = hidden_states + self.dropout(attention_output[0])
+ outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+class LongT5Block(nn.Module):
+ def __init__(self, config, has_relative_attention_bias=False):
+ super().__init__()
+ self.is_decoder = config.is_decoder
+ if config.is_decoder:
+ attention_layer = LongT5LayerSelfAttention
+ elif config.encoder_attention_type == "local":
+ attention_layer = LongT5LayerLocalSelfAttention
+ elif config.encoder_attention_type == "transient-global":
+ attention_layer = LongT5LayerTransientGlobalSelfAttention
+ else:
+ raise ValueError(
+ "For encoder attention mechanism, either `local` or `transient-global` attention type is expected, "
+ f"but got {config.encoder_attention_type}."
+ )
+ self.layer = nn.ModuleList()
+ self.layer.append(attention_layer(config, has_relative_attention_bias=has_relative_attention_bias))
+ if self.is_decoder:
+ self.layer.append(LongT5LayerCrossAttention(config))
+
+ self.layer.append(LongT5LayerFF(config))
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ encoder_decoder_position_bias=None,
+ layer_head_mask=None,
+ cross_attn_layer_head_mask=None,
+ past_key_value=None,
+ use_cache=False,
+ output_attentions=False,
+ return_dict=True,
+ ):
+ if past_key_value is not None:
+ if not self.is_decoder:
+ logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
+
+ if len(past_key_value) != expected_num_past_key_values:
+ raise ValueError(
+ f"There should be {expected_num_past_key_values} past states. "
+ f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
+ f"Got {len(past_key_value)} past key / value states"
+ )
+
+ self_attn_past_key_value = past_key_value[:2]
+ cross_attn_past_key_value = past_key_value[2:]
+ else:
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
+
+ self_attention_outputs = self.layer[0](
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=self_attn_past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states, present_key_value_state = self_attention_outputs[:2]
+ attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
+
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
+ if do_cross_attention:
+ # the actual query length is unknown for cross attention
+ # if using past key value states. Need to inject it here
+ if present_key_value_state is not None:
+ query_length = present_key_value_state[0].shape[2]
+ else:
+ query_length = None
+
+ cross_attention_outputs = self.layer[1](
+ hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ position_bias=encoder_decoder_position_bias,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=cross_attn_past_key_value,
+ query_length=query_length,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = cross_attention_outputs[0]
+
+ # Combine self attn and cross attn key value states
+ if present_key_value_state is not None:
+ present_key_value_state = present_key_value_state + cross_attention_outputs[1]
+
+ # Keep cross-attention outputs and relative position weights
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
+
+ # Apply Feed Forward layer
+ hidden_states = self.layer[-1](hidden_states)
+
+ outputs = (hidden_states,)
+
+ if use_cache:
+ outputs = outputs + (present_key_value_state,) + attention_outputs
+ else:
+ outputs = outputs + attention_outputs
+
+ return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+
+
+class LongT5PreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = LongT5Config
+ base_model_prefix = "transformer"
+ supports_gradient_checkpointing = True
+
+ @property
+ # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs
+ def dummy_inputs(self):
+ input_ids = torch.tensor(DUMMY_INPUTS)
+ input_mask = torch.tensor(DUMMY_MASK)
+ dummy_inputs = {
+ "decoder_input_ids": input_ids,
+ "input_ids": input_ids,
+ "decoder_attention_mask": input_mask,
+ }
+ return dummy_inputs
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ factor = self.config.initializer_factor # Used for testing weights initialization
+ if isinstance(module, LongT5LayerNorm):
+ module.weight.data.fill_(factor * 1.0)
+ elif isinstance(module, (LongT5Model, LongT5ForConditionalGeneration, LongT5EncoderModel)):
+ # Mesh TensorFlow embeddings initialization
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
+ elif isinstance(module, LongT5DenseActDense):
+ # Mesh TensorFlow FF initialization
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
+ module.wi.bias.data.zero_()
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
+ module.wo.bias.data.zero_()
+ elif isinstance(module, LongT5DenseGatedActDense):
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
+ module.wi_0.bias.data.zero_()
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
+ module.wi_1.bias.data.zero_()
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
+ module.wo.bias.data.zero_()
+ elif isinstance(module, (LongT5Attention, LongT5LocalAttention, LongT5TransientGlobalAttention)):
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
+ d_model = self.config.d_model
+ key_value_proj_dim = self.config.d_kv
+ n_heads = self.config.num_heads
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
+ if module.has_relative_attention_bias:
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
+ if isinstance(module, LongT5TransientGlobalAttention):
+ module.global_relative_attention_bias.weight.data.normal_(
+ mean=0.0, std=factor * ((d_model) ** -0.5)
+ )
+
+ # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._set_gradient_checkpointing with T5->LongT5
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (LongT5Attention, LongT5Stack)):
+ module.gradient_checkpointing = value
+
+ # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5
+ def _shift_right(self, input_ids):
+ decoder_start_token_id = self.config.decoder_start_token_id
+ pad_token_id = self.config.pad_token_id
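+ # Example of the shift performed below: labels [[a, b, c]] become decoder inputs
+ # [[decoder_start_token_id, a, b]]; any -100 entries (ignored label positions) are afterwards
+ # replaced by pad_token_id so the decoder only sees valid token ids.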
+
+ assert decoder_start_token_id is not None, (
+ "self.model.config.decoder_start_token_id has to be defined. In LongT5 it is usually set to the"
+ " pad_token_id. See LongT5 docs for more information"
+ )
+
+ # shift inputs to the right
+ if is_torch_fx_proxy(input_ids):
+ # Item assignment is not supported natively for proxies.
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
+ else:
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+ shifted_input_ids[..., 0] = decoder_start_token_id
+
+ assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ return shifted_input_ids
+
+
+class LongT5Stack(LongT5PreTrainedModel):
+ def __init__(self, config, embed_tokens=None):
+ super().__init__(config)
+
+ self.embed_tokens = embed_tokens
+ self.is_decoder = config.is_decoder
+
+ self.local_radius = config.local_radius
+ self.block_len = self.local_radius + 1
+
+ self.block = nn.ModuleList(
+ [LongT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
+ )
+ self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ self.gradient_checkpointing = False
+
+ # Copied from transformers.models.t5.modeling_t5.T5Stack.get_input_embeddings
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ # Copied from transformers.models.t5.modeling_t5.T5Stack.set_input_embeddings
+ def set_input_embeddings(self, new_embeddings):
+ self.embed_tokens = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ inputs_embeds=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
+ raise ValueError(
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
+
+ if inputs_embeds is None:
+ assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ batch_size, seq_length = input_shape
+
+ # required mask seq length can be calculated via length of past
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
+
+ if use_cache is True:
+ assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
+
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+ if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
+ encoder_seq_length = encoder_hidden_states.shape[1]
+ encoder_attention_mask = torch.ones(
+ batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
+ )
+
+ # initialize past_key_values with `None` if past does not exist
+ if past_key_values is None:
+ past_key_values = [None] * len(self.block)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used
+ if self.is_decoder:
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, input_shape, inputs_embeds.device
+ )
+ elif self.config.encoder_attention_type == "local":
+ extended_attention_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device)
+ else: # we need to use both local attention mask and standard extended mask for transient-global attention
+ extended_attention_mask = attention_mask
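+ # Note: for transient-global attention the raw 2D mask is forwarded unchanged, because the attention
+ # module itself derives both the local block mask and the global/side mask from it.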
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
+ present_key_value_states = () if use_cache else None
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+ position_bias = None
+ encoder_decoder_position_bias = None
+
+ hidden_states = self.dropout(inputs_embeds)
+
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
+ layer_head_mask = head_mask[i]
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return tuple(module(*inputs, use_cache, output_attentions))
+
+ return custom_forward
+
+ layer_outputs = checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ extended_attention_mask,
+ position_bias,
+ encoder_hidden_states,
+ encoder_extended_attention_mask,
+ encoder_decoder_position_bias,
+ layer_head_mask,
+ cross_attn_layer_head_mask,
+ None, # past_key_value is always None with gradient checkpointing
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask=extended_attention_mask,
+ position_bias=position_bias,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ layer_head_mask=layer_head_mask,
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ # layer_outputs is a tuple with:
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+ if use_cache is False:
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+
+ hidden_states, present_key_value_state = layer_outputs[:2]
+
+ # We share the position biases between the layers - the first layer stores them
+ # layer_outputs = hidden-states, key-value-states, (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+ position_bias = layer_outputs[2]
+ if self.is_decoder and encoder_hidden_states is not None:
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
+ # append next layer key value states
+ if use_cache:
+ present_key_value_states = present_key_value_states + (present_key_value_state,)
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[3],)
+ if self.is_decoder:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
+
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ present_key_value_states,
+ all_hidden_states,
+ all_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=present_key_value_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+LONGT5_START_DOCSTRING = r"""
+
+ The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long
+ Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo
+ Ni, Yun-Hsuan Sung and Yinfei Yang. It's an encoder-decoder transformer pre-trained in a text-to-text denoising
+ generative setting. The LongT5 model is an extension of the T5 model, and it enables using one of two
+ efficient attention mechanisms: (1) Local attention, or (2) Transient-Global attention.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LongT5Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+LONGT5_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
+ you should be able to pad the inputs on both the right and the left.
+
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ To know more about how to prepare `input_ids` for pretraining, take a look at [LONGT5
+ Training](./longt5#training).
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ To know more about how to prepare `decoder_input_ids` for pretraining, take a look at [LONGT5
+ Training](./longt5#training).
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
+ `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+ Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
+ the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+ input (see `past_key_values`). This is useful if you want more control over how to convert
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+ of `inputs_embeds`.
+
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+LONGT5_ENCODER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
+ you should be able to pad the inputs on both the right and the left.
+
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ To know more about how to prepare `input_ids` for pretraining, take a look at [LONGT5
+ Training](./longt5#training).
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+__HEAD_MASK_WARNING_MSG = """
+The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
+`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
+If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
+num_heads)`.
+"""
+
+
+@add_start_docstrings(
+ "The bare LONGT5 Model transformer outputting raw hidden-states without any specific head on top.",
+ LONGT5_START_DOCSTRING,
+)
+class LongT5Model(LongT5PreTrainedModel):
+ _keys_to_ignore_on_load_missing = [
+ r"encoder.embed_tokens.weight",
+ r"decoder.embed_tokens.weight",
+ ]
+ _keys_to_ignore_on_load_unexpected = [
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
+ ]
+
+ def __init__(self, config: LongT5Config):
+ super().__init__(config)
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+ encoder_config = copy.deepcopy(config)
+ encoder_config.is_decoder = False
+ encoder_config.use_cache = False
+ encoder_config.is_encoder_decoder = False
+ self.encoder = LongT5Stack(encoder_config, self.shared)
+
+ decoder_config = copy.deepcopy(config)
+ decoder_config.is_decoder = True
+ decoder_config.is_encoder_decoder = False
+ decoder_config.num_layers = config.num_decoder_layers
+ self.decoder = LongT5Stack(decoder_config, self.shared)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, new_embeddings):
+ self.shared = new_embeddings
+ self.encoder.set_input_embeddings(new_embeddings)
+ self.decoder.set_input_embeddings(new_embeddings)
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_decoder(self):
+ return self.decoder
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import T5Tokenizer, LongT5Model
+
+ >>> tokenizer = T5Tokenizer.from_pretrained("google/long-t5-local-base")
+ >>> model = LongT5Model.from_pretrained("google/long-t5-local-base")
+
+ >>> # Let's try a very long encoder input.
+ >>> input_ids = tokenizer(
+ ... 100 * "Studies have been shown that owning a dog is good for you", return_tensors="pt"
+ ... ).input_ids # Batch size 1
+
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
+
+ >>> # forward pass
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+ if head_mask is not None and decoder_head_mask is None:
+ if self.config.num_layers == self.config.num_decoder_layers:
+ warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
+ decoder_head_mask = head_mask
+
+ # Encode if needed (training, first prediction pass)
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ past_key_values=past_key_values,
+ encoder_hidden_states=hidden_states,
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
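Since `forward` only runs the encoder when `encoder_outputs` is not supplied, the expensive long-input encoding can be computed once and reused across several decoder calls. A minimal sketch of that pattern, using the same public checkpoint as the docstring example above:

```python
from transformers import AutoTokenizer, LongT5Model

tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
model = LongT5Model.from_pretrained("google/long-t5-local-base")

inputs = tokenizer(100 * "Studies have shown that owning a dog is good for you ", return_tensors="pt")

# Run the long encoder once ...
encoder_outputs = model.get_encoder()(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)

# ... and reuse it for different decoder prefixes without re-encoding the input.
for prefix in ["Studies show that", "Dogs are"]:
    decoder_input_ids = tokenizer(prefix, return_tensors="pt").input_ids
    outputs = model(
        encoder_outputs=encoder_outputs,
        attention_mask=inputs.attention_mask,
        decoder_input_ids=decoder_input_ids,
    )
    print(outputs.last_hidden_state.shape)
```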
+
+@add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING)
+class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
+ _keys_to_ignore_on_load_missing = [
+ r"encoder.embed_tokens.weight",
+ r"decoder.embed_tokens.weight",
+ r"lm_head.weight",
+ ]
+ _keys_to_ignore_on_load_unexpected = [
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
+ ]
+
+ def __init__(self, config: LongT5Config):
+ super().__init__(config)
+ self.model_dim = config.d_model
+
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+ encoder_config = copy.deepcopy(config)
+ encoder_config.is_decoder = False
+ encoder_config.use_cache = False
+ encoder_config.is_encoder_decoder = False
+ self.encoder = LongT5Stack(encoder_config, self.shared)
+
+ decoder_config = copy.deepcopy(config)
+ decoder_config.is_decoder = True
+ decoder_config.is_encoder_decoder = False
+ decoder_config.num_layers = config.num_decoder_layers
+ self.decoder = LongT5Stack(decoder_config, self.shared)
+
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, new_embeddings):
+ self.shared = new_embeddings
+ self.encoder.set_input_embeddings(new_embeddings)
+ self.decoder.set_input_embeddings(new_embeddings)
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_decoder(self):
+ return self.decoder
+
+ @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
+ labels in `[0, ..., config.vocab_size - 1]`.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
+ >>> model = LongT5ForConditionalGeneration.from_pretrained(
+ ... "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps"
+ ... )
+
+ >>> # Let's try a very long input.
+ >>> input_ids = tokenizer(
+ ... "summarize: " + 100 * "studies have shown that owning a dog is good for you ", return_tensors="pt"
+ ... ).input_ids # Batch size 1
+
+ >>> outputs = model.generate(input_ids)
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ abstractthe aim of this article is to summarize the studies have shown that owning a dog
+ ```"""
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+ if head_mask is not None and decoder_head_mask is None:
+ if self.config.num_layers == self.config.num_decoder_layers:
+ warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
+ decoder_head_mask = head_mask
+
+ # Encode if needed (training, first prediction pass)
+ if encoder_outputs is None:
+ # Convert encoder inputs in embeddings if needed
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+ # get decoder inputs from shifting lm labels to the right
+ decoder_input_ids = self._shift_right(labels)
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ past_key_values=past_key_values,
+ encoder_hidden_states=hidden_states,
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = decoder_outputs[0]
+
+ if self.config.tie_word_embeddings:
+ # Rescale output before projecting on vocab
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ sequence_output = sequence_output * (self.model_dim**-0.5)
+
+ lm_logits = self.lm_head(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+
+ if not return_dict:
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+ return ((loss,) + output) if loss is not None else output
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past=None,
+ attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ **kwargs
+ ):
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "decoder_input_ids": input_ids,
+ "past_key_values": past,
+ "encoder_outputs": encoder_outputs,
+ "attention_mask": attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache,
+ }
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return self._shift_right(labels)
+
+ def _reorder_cache(self, past, beam_idx):
+ # if decoder past is not included in output
+ # speedy decoding is disabled and no need to reorder
+ if past is None:
+ logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
+ return past
+
+ reordered_decoder_past = ()
+ for layer_past_states in past:
+ # get the correct batch idx from the layer past batch dim
+ # (the batch dim of `past` is the 1st dimension, hence the index_select over dim 0 below)
+ reordered_layer_past_states = ()
+ for layer_past_state in layer_past_states:
+ # need to set correct `past` for each of the four key / value states
+ reordered_layer_past_states = reordered_layer_past_states + (
+ layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
+ )
+
+ assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
+ assert len(reordered_layer_past_states) == len(layer_past_states)
+
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
+ return reordered_decoder_past
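`prepare_inputs_for_generation` and `_reorder_cache` are only exercised through `generate()`: with `use_cache=True` the decoder is fed a single new token per step, and the cached key/value states are re-indexed after every beam-search step. A minimal sketch, assuming the same PubMed summarization checkpoint used in the docstring above:

```python
from transformers import AutoTokenizer, LongT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
model = LongT5ForConditionalGeneration.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")

input_ids = tokenizer(
    "summarize: " + 100 * "studies have shown that owning a dog is good for you ", return_tensors="pt"
).input_ids

# Beam search with caching: each step only feeds the last generated token
# (see `prepare_inputs_for_generation`) and `_reorder_cache` realigns the
# past key/value states with the surviving beams.
summary_ids = model.generate(input_ids, num_beams=4, max_length=64, use_cache=True)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```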
+
+
+@add_start_docstrings(
+ "The bare LONGT5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
+ LONGT5_START_DOCSTRING,
+)
+class LongT5EncoderModel(LongT5PreTrainedModel):
+ authorized_missing_keys = [
+ r"encoder.embed_tokens.weight",
+ ]
+
+ def __init__(self, config: LongT5Config):
+ super().__init__(config)
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+ encoder_config = copy.deepcopy(config)
+ encoder_config.use_cache = False
+ encoder_config.is_encoder_decoder = False
+ self.encoder = LongT5Stack(encoder_config, self.shared)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, new_embeddings):
+ self.shared = new_embeddings
+ self.encoder.set_input_embeddings(new_embeddings)
+
+ def get_encoder(self):
+ return self.encoder
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(LONGT5_ENCODER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LongT5EncoderModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
+ >>> model = LongT5EncoderModel.from_pretrained("google/long-t5-local-base")
+ >>> input_ids = tokenizer(
+ ... 100 * "Studies have been shown that owning a dog is good for you ", return_tensors="pt"
+ ... ).input_ids # Batch size 1
+ >>> outputs = model(input_ids=input_ids)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ return encoder_outputs
diff --git a/src/transformers/models/luke/__init__.py b/src/transformers/models/luke/__init__.py
index d18d016b5026..42165923b1d8 100644
--- a/src/transformers/models/luke/__init__.py
+++ b/src/transformers/models/luke/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
@@ -26,12 +26,21 @@
"tokenization_luke": ["LukeTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_luke"] = [
"LUKE_PRETRAINED_MODEL_ARCHIVE_LIST",
"LukeForEntityClassification",
"LukeForEntityPairClassification",
"LukeForEntitySpanClassification",
+ "LukeForMultipleChoice",
+ "LukeForQuestionAnswering",
+ "LukeForSequenceClassification",
+ "LukeForTokenClassification",
"LukeForMaskedLM",
"LukeModel",
"LukePreTrainedModel",
@@ -42,13 +51,22 @@
from .configuration_luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig
from .tokenization_luke import LukeTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_luke import (
LUKE_PRETRAINED_MODEL_ARCHIVE_LIST,
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
+ LukeForMultipleChoice,
+ LukeForQuestionAnswering,
+ LukeForSequenceClassification,
+ LukeForTokenClassification,
LukeModel,
LukePreTrainedModel,
)
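The `try`/`except OptionalDependencyNotAvailable` guard introduced here is the same pattern applied to every optional backend in this patch: when the backend is missing, the guarded entries are simply left out of the lazy import table instead of failing at import time. A condensed sketch of the pattern with a placeholder module name (`foo` and its classes are hypothetical):

```python
from transformers.utils import OptionalDependencyNotAvailable, is_torch_available

_import_structure = {"configuration_foo": ["FooConfig"]}  # "foo" is a placeholder module name

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # torch is missing: leave the torch-backed classes out of the lazy import table
    pass
else:
    _import_structure["modeling_foo"] = ["FooModel"]

print(_import_structure)
```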
diff --git a/src/transformers/models/luke/configuration_luke.py b/src/transformers/models/luke/configuration_luke.py
index c5a7a8f581a1..8f7438cc3c6a 100644
--- a/src/transformers/models/luke/configuration_luke.py
+++ b/src/transformers/models/luke/configuration_luke.py
@@ -74,6 +74,8 @@ class LukeConfig(PretrainedConfig):
Whether or not the model should use the entity-aware self-attention mechanism proposed in [LUKE: Deep
Contextualized Entity Representations with Entity-aware Self-attention (Yamada et
al.)](https://arxiv.org/abs/2010.01057).
+ classifier_dropout (`float`, *optional*):
+ The dropout ratio for the classification head.
Examples:
@@ -108,6 +110,7 @@ def __init__(
initializer_range=0.02,
layer_norm_eps=1e-12,
use_entity_aware_attention=True,
+ classifier_dropout=None,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
@@ -131,3 +134,4 @@ def __init__(
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.use_entity_aware_attention = use_entity_aware_attention
+ self.classifier_dropout = classifier_dropout
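`classifier_dropout` defaults to `None`, in which case the new LUKE heads below fall back to `hidden_dropout_prob`. A small sketch of how the two settings resolve:

```python
from transformers import LukeConfig

default_config = LukeConfig()                        # classifier_dropout is None
custom_config = LukeConfig(classifier_dropout=0.2)   # explicit head dropout

for config in (default_config, custom_config):
    # Mirrors the head construction used in modeling_luke.py below.
    dropout_prob = (
        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
    )
    print(dropout_prob)
```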
diff --git a/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py
index 520ae61b43ec..d2b2323b289c 100644
--- a/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py
@@ -77,13 +77,17 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
raise ValueError(f"Missing keys {', '.join(missing_keys)}. Expected only missing embeddings.position_ids")
if not (all(key.startswith("entity_predictions") or key.startswith("lm_head") for key in unexpected_keys)):
raise ValueError(
- f"Unexpected keys {', '.join([key for key in unexpected_keys if not (key.startswith('entity_predictions') or key.startswith('lm_head'))])}"
+ "Unexpected keys"
+ f" {', '.join([key for key in unexpected_keys if not (key.startswith('entity_predictions') or key.startswith('lm_head'))])}"
)
# Check outputs
tokenizer = LukeTokenizer.from_pretrained(pytorch_dump_folder_path, task="entity_classification")
- text = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ text = (
+ "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the"
+ " new world number one avoid a humiliating second- round exit at Wimbledon ."
+ )
span = (39, 42)
encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt")
@@ -116,7 +120,8 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
if not (outputs.entity_last_hidden_state.shape != expected_shape):
raise ValueError(
- f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is {expected_shape}"
+ f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is"
+ f" {expected_shape}"
)
if not torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
raise ValueError
@@ -129,7 +134,7 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
def load_entity_vocab(entity_vocab_path):
entity_vocab = {}
with open(entity_vocab_path, "r", encoding="utf-8") as f:
- for (index, line) in enumerate(f):
+ for index, line in enumerate(f):
title, _ = line.rstrip().split("\t")
entity_vocab[title] = index
diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py
index 7388e2031bab..6d40dfafe8e4 100644
--- a/src/transformers/models/luke/modeling_luke.py
+++ b/src/transformers/models/luke/modeling_luke.py
@@ -21,6 +21,7 @@
import torch
import torch.utils.checkpoint
from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN, gelu
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
@@ -28,6 +29,7 @@
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
ModelOutput,
+ add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
@@ -247,6 +249,147 @@ class EntitySpanClassificationOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
+@dataclass
+class LukeSequenceClassifierOutput(ModelOutput):
+ """
+ Outputs of sentence classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
+ layer plus the initial entity embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class LukeTokenClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of token classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
+ Classification scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
+ layer plus the initial entity embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class LukeQuestionAnsweringModelOutput(ModelOutput):
+ """
+ Outputs of question answering models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+ start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Span-start scores (before SoftMax).
+ end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Span-end scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
+ layer plus the initial entity embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ start_logits: torch.FloatTensor = None
+ end_logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class LukeMultipleChoiceModelOutput(ModelOutput):
+ """
+ Outputs of multiple choice models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
+ *num_choices* is the second dimension of the input tensors. (see *input_ids* above).
+
+ Classification scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
+ layer plus the initial entity embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
class LukeEmbeddings(nn.Module):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
@@ -874,7 +1017,8 @@ def _set_gradient_checkpointing(self, module, value=False):
@add_start_docstrings(
- "The bare LUKE model transformer outputting raw hidden-states for both word tokens and entities without any specific head on top.",
+ "The bare LUKE model transformer outputting raw hidden-states for both word tokens and entities without any"
+ " specific head on top.",
LUKE_START_DOCSTRING,
)
class LukeModel(LukePreTrainedModel):
@@ -1075,7 +1219,7 @@ def get_extended_attention_mask(
raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape})")
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
return extended_attention_mask
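Replacing the hard-coded `-10000.0` with `torch.finfo(self.dtype).min` keeps the additive attention mask at the most negative value representable in the model's dtype, which matters under fp16 where `-10000.0` is nowhere near the dtype minimum. A standalone sketch of the computation for a 2D padding mask:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # 1 = attend, 0 = padding
dtype = torch.float16

# Broadcastable shape (batch_size, 1, 1, seq_len), as in get_extended_attention_mask
extended_attention_mask = attention_mask[:, None, None, :].to(dtype=dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min

# Padded positions get the dtype minimum (~-65504 for float16) instead of -10000.0,
# so they stay effectively -inf after being added to the attention scores.
print(extended_attention_mask)
```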
@@ -1228,24 +1372,31 @@ def forward(
loss = mlm_loss
mep_loss = None
- entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
- if entity_labels is not None:
- mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1))
- if loss is None:
- loss = mep_loss
- else:
- loss = loss + mep_loss
+ entity_logits = None
+ if outputs.entity_last_hidden_state is not None:
+ entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
+ if entity_labels is not None:
+ mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1))
+ if loss is None:
+ loss = mep_loss
+ else:
+ loss = loss + mep_loss
if not return_dict:
- output = (logits, entity_logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions)
- if mlm_loss is not None and mep_loss is not None:
- return (loss, mlm_loss, mep_loss) + output
- elif mlm_loss is not None:
- return (loss, mlm_loss) + output
- elif mep_loss is not None:
- return (loss, mep_loss) + output
- else:
- return output
+ return tuple(
+ v
+ for v in [
+ loss,
+ mlm_loss,
+ mep_loss,
+ logits,
+ entity_logits,
+ outputs.hidden_states,
+ outputs.entity_hidden_states,
+ outputs.attentions,
+ ]
+ if v is not None
+ )
return LukeMaskedLMOutput(
loss=loss,
@@ -1357,13 +1508,11 @@ def forward(
loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
if not return_dict:
- output = (
- logits,
- outputs.hidden_states,
- outputs.entity_hidden_states,
- outputs.attentions,
+ return tuple(
+ v
+ for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
+ if v is not None
)
- return ((loss,) + output) if loss is not None else output
return EntityClassificationOutput(
loss=loss,
@@ -1477,13 +1626,11 @@ def forward(
loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
if not return_dict:
- output = (
- logits,
- outputs.hidden_states,
- outputs.entity_hidden_states,
- outputs.attentions,
+ return tuple(
+ v
+ for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
+ if v is not None
)
- return ((loss,) + output) if loss is not None else output
return EntityPairClassificationOutput(
loss=loss,
@@ -1617,13 +1764,11 @@ def forward(
loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
if not return_dict:
- output = (
- logits,
- outputs.hidden_states,
- outputs.entity_hidden_states,
- outputs.attentions,
+ return tuple(
+ v
+ for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
+ if v is not None
)
- return ((loss,) + output) if loss is not None else output
return EntitySpanClassificationOutput(
loss=loss,
@@ -1632,3 +1777,460 @@ def forward(
entity_hidden_states=outputs.entity_hidden_states,
attentions=outputs.attentions,
)
+
+
+@add_start_docstrings(
+ """
+ The LUKE Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+ pooled output) e.g. for GLUE tasks.
+ """,
+ LUKE_START_DOCSTRING,
+)
+class LukeForSequenceClassification(LukePreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.luke = LukeModel(config)
+ self.dropout = nn.Dropout(
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=LukeSequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ entity_ids: Optional[torch.LongTensor] = None,
+ entity_attention_mask: Optional[torch.FloatTensor] = None,
+ entity_token_type_ids: Optional[torch.LongTensor] = None,
+ entity_position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, LukeSequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.luke(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ entity_ids=entity_ids,
+ entity_attention_mask=entity_attention_mask,
+ entity_token_type_ids=entity_token_type_ids,
+ entity_position_ids=entity_position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ )
+
+ pooled_output = outputs.pooler_output
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
+ if v is not None
+ )
+
+ return LukeSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ entity_hidden_states=outputs.entity_hidden_states,
+ attentions=outputs.attentions,
+ )
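A minimal usage sketch for the new sequence classification head. `studio-ousia/luke-base` is the base LUKE checkpoint; the classifier weights are freshly initialized here, so the predicted class is only illustrative:

```python
import torch
from transformers import LukeTokenizer, LukeForSequenceClassification

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
model = LukeForSequenceClassification.from_pretrained("studio-ousia/luke-base", num_labels=2)

inputs = tokenizer("LUKE also works as a plain text encoder.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One logit per label; argmax picks the (untrained, hence arbitrary) predicted class.
predicted_class = outputs.logits.argmax(-1).item()
print(predicted_class)
```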
+
+
+@add_start_docstrings(
+ """
+ The LUKE Model with a token classification head on top (a linear layer on top of the hidden-states output). To
+ solve the Named-Entity Recognition (NER) task with LUKE, `LukeForEntitySpanClassification` is more suitable than
+ this class.
+ """,
+ LUKE_START_DOCSTRING,
+)
+class LukeForTokenClassification(LukePreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.luke = LukeModel(config, add_pooling_layer=False)
+ self.dropout = nn.Dropout(
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=LukeTokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ entity_ids: Optional[torch.LongTensor] = None,
+ entity_attention_mask: Optional[torch.FloatTensor] = None,
+ entity_token_type_ids: Optional[torch.LongTensor] = None,
+ entity_position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, LukeTokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.luke(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ entity_ids=entity_ids,
+ entity_attention_mask=entity_attention_mask,
+ entity_token_type_ids=entity_token_type_ids,
+ entity_position_ids=entity_position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
+ if v is not None
+ )
+
+ return LukeTokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ entity_hidden_states=outputs.entity_hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ The LUKE Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ LUKE_START_DOCSTRING,
+)
+class LukeForQuestionAnswering(LukePreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+
+ self.luke = LukeModel(config, add_pooling_layer=False)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=LukeQuestionAnsweringModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ entity_ids: Optional[torch.LongTensor] = None,
+ entity_attention_mask: Optional[torch.FloatTensor] = None,
+ entity_token_type_ids: Optional[torch.LongTensor] = None,
+ entity_position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, LukeQuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.luke(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ entity_ids=entity_ids,
+ entity_attention_mask=entity_attention_mask,
+ entity_token_type_ids=entity_token_type_ids,
+ entity_position_ids=entity_position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1)
+ end_logits = end_logits.squeeze(-1)
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, the gathered positions may carry an extra dimension; squeeze it
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions.clamp_(0, ignored_index)
+ end_positions.clamp_(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ total_loss,
+ start_logits,
+ end_logits,
+ outputs.hidden_states,
+ outputs.entity_hidden_states,
+ outputs.attentions,
+ ]
+ if v is not None
+ )
+
+ return LukeQuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ entity_hidden_states=outputs.entity_hidden_states,
+ attentions=outputs.attentions,
+ )
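The span loss above clamps out-of-range answer positions to `ignored_index` (the sequence length) and tells `CrossEntropyLoss` to skip them, so answers lost to truncation contribute no gradient. A self-contained sketch of just that piece, with made-up logits and positions:

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len = 2, 8
start_logits = torch.randn(batch_size, seq_len)
end_logits = torch.randn(batch_size, seq_len)

# The second example's answer was truncated away: its positions point past the sequence.
start_positions = torch.tensor([3, 50])
end_positions = torch.tensor([5, 60])

ignored_index = start_logits.size(1)  # == seq_len
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

# Positions clamped to `ignored_index` are skipped by the loss, so only the first example contributes.
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2
print(total_loss)
```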
+
+
+@add_start_docstrings(
+ """
+ The LUKE Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+ softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ LUKE_START_DOCSTRING,
+)
+class LukeForMultipleChoice(LukePreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.luke = LukeModel(config)
+ self.dropout = nn.Dropout(
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=LukeMultipleChoiceModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ entity_ids: Optional[torch.LongTensor] = None,
+ entity_attention_mask: Optional[torch.FloatTensor] = None,
+ entity_token_type_ids: Optional[torch.LongTensor] = None,
+ entity_position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, LukeMultipleChoiceModelOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ entity_ids = entity_ids.view(-1, entity_ids.size(-1)) if entity_ids is not None else None
+ entity_attention_mask = (
+ entity_attention_mask.view(-1, entity_attention_mask.size(-1))
+ if entity_attention_mask is not None
+ else None
+ )
+ entity_token_type_ids = (
+ entity_token_type_ids.view(-1, entity_token_type_ids.size(-1))
+ if entity_token_type_ids is not None
+ else None
+ )
+ entity_position_ids = (
+ entity_position_ids.view(-1, entity_position_ids.size(-2), entity_position_ids.size(-1))
+ if entity_position_ids is not None
+ else None
+ )
+
+ outputs = self.luke(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ entity_ids=entity_ids,
+ entity_attention_mask=entity_attention_mask,
+ entity_token_type_ids=entity_token_type_ids,
+ entity_position_ids=entity_position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ )
+
+ pooled_output = outputs.pooler_output
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ loss,
+ reshaped_logits,
+ outputs.hidden_states,
+ outputs.entity_hidden_states,
+ outputs.attentions,
+ ]
+ if v is not None
+ )
+
+ return LukeMultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ entity_hidden_states=outputs.entity_hidden_states,
+ attentions=outputs.attentions,
+ )
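The multiple-choice head flattens `(batch_size, num_choices, seq_len)` inputs into `(batch_size * num_choices, seq_len)` before the encoder and reshapes the per-choice scores back into `(batch_size, num_choices)` logits. A shape-only sketch of that round trip (random tensors stand in for the encoder and pooled output):

```python
import torch

batch_size, num_choices, seq_len, hidden_size = 2, 4, 16, 8

input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_len))

# Flatten choices into the batch dimension, as LukeForMultipleChoice.forward does.
flat_input_ids = input_ids.view(-1, input_ids.size(-1))             # (8, 16)

# Stand-in for the encoder's pooled output, followed by the Linear(hidden_size, 1) classifier.
classifier = torch.nn.Linear(hidden_size, 1)
pooled_output = torch.randn(batch_size * num_choices, hidden_size)
logits = classifier(pooled_output)                                   # (8, 1): one score per flattened choice

reshaped_logits = logits.view(-1, num_choices)                       # (2, 4): one score per choice
labels = torch.tensor([1, 3])                                        # correct choice index per example
loss = torch.nn.CrossEntropyLoss()(reshaped_logits, labels)
print(flat_input_ids.shape, reshaped_logits.shape, loss.item())
```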
diff --git a/src/transformers/models/luke/tokenization_luke.py b/src/transformers/models/luke/tokenization_luke.py
index e75fda42ca83..3cbc9218c0f9 100644
--- a/src/transformers/models/luke/tokenization_luke.py
+++ b/src/transformers/models/luke/tokenization_luke.py
@@ -253,7 +253,8 @@ def __init__(
self.max_entity_length = 2
else:
raise ValueError(
- f"Task {task} not supported. Select task from ['entity_classification', 'entity_pair_classification', 'entity_span_classification'] only."
+ f"Task {task} not supported. Select task from ['entity_classification', 'entity_pair_classification',"
+ " 'entity_span_classification'] only."
)
self.max_mention_length = max_mention_length
@@ -598,7 +599,7 @@ def _check_entity_input_format(self, entities: Optional[EntityInput], entity_spa
raise ValueError("entity_spans should be given as a list")
elif len(entity_spans) > 0 and not isinstance(entity_spans[0], tuple):
raise ValueError(
- "entity_spans should be given as a list of tuples " "containing the start and end character indices"
+ "entity_spans should be given as a list of tuples containing the start and end character indices"
)
if entities is not None:
@@ -1007,7 +1008,8 @@ def prepare_for_model(
if num_invalid_entities != 0:
logger.warning(
- f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the truncation of input tokens"
+ f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the"
+ " truncation of input tokens"
)
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and total_entity_len > max_entity_length:
@@ -1032,7 +1034,7 @@ def prepare_for_model(
entity_position_ids = []
entity_start_positions = []
entity_end_positions = []
- for (token_spans, offset) in (
+ for token_spans, offset in (
(valid_entity_token_spans, entity_token_offset),
(valid_pair_entity_token_spans, pair_entity_token_offset),
):
@@ -1181,7 +1183,7 @@ def pad(
else:
raise ValueError(
f"type of {first_element} unknown: {type(first_element)}. "
- f"Should be one of a python, numpy, pytorch or tensorflow object."
+ "Should be one of a python, numpy, pytorch or tensorflow object."
)
for key, value in encoded_inputs.items():
@@ -1384,6 +1386,6 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(entity_vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.entity_vocab, ensure_ascii=False))
+ f.write(json.dumps(self.entity_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return vocab_file, merge_file, entity_vocab_file
diff --git a/src/transformers/models/lxmert/__init__.py b/src/transformers/models/lxmert/__init__.py
index 38d9d5e67e9f..0b8b58bc9986 100644
--- a/src/transformers/models/lxmert/__init__.py
+++ b/src/transformers/models/lxmert/__init__.py
@@ -18,7 +18,13 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -26,10 +32,20 @@
"tokenization_lxmert": ["LxmertTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_lxmert_fast"] = ["LxmertTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_lxmert"] = [
"LxmertEncoder",
"LxmertForPreTraining",
@@ -40,7 +56,12 @@
"LxmertXLayer",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_lxmert"] = [
"TF_LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFLxmertForPreTraining",
@@ -55,10 +76,20 @@
from .configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
from .tokenization_lxmert import LxmertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_lxmert_fast import LxmertTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_lxmert import (
LxmertEncoder,
LxmertForPreTraining,
@@ -69,7 +100,12 @@
LxmertXLayer,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_lxmert import (
TF_LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLxmertForPreTraining,
diff --git a/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py
index 7debd71af3b3..f8eb86f1d1e4 100755
--- a/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py
@@ -51,8 +51,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained model. \n"
- "This specifies the model architecture.",
+ help="The config json file corresponding to the pre-trained model. \nThis specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py
index c9b2541251e8..6ba852afcb1b 100644
--- a/src/transformers/models/lxmert/modeling_lxmert.py
+++ b/src/transformers/models/lxmert/modeling_lxmert.py
@@ -336,7 +336,7 @@ def transpose_for_scores(self, x):
self.num_attention_heads,
self.attention_head_size,
)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, context, attention_mask=None, output_attentions=False):
@@ -365,7 +365,7 @@ def forward(self, hidden_states, context, attention_mask=None, output_attentions
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
@@ -959,13 +959,13 @@ def forward(
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
# Process the visual attention mask
if visual_attention_mask is not None:
extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=self.dtype)
- extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
+ extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * torch.finfo(self.dtype).min
else:
extended_visual_attention_mask = None
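The masking change above swaps the hard-coded `-10000.0` for `torch.finfo(self.dtype).min`, i.e. the most negative value representable in whatever dtype the model runs in, so masked positions are fully suppressed in float16 as well as float32. A standalone sketch of that additive-mask construction with made-up inputs:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # 1 = attend, 0 = mask
dtype = torch.float32                             # would be torch.float16 when running the model in fp16

# Broadcastable extended mask, as in the hunk above:
# 0 where attended, the dtype's most negative representable value where masked.
extended = attention_mask[:, None, None, :].to(dtype)
extended = (1.0 - extended) * torch.finfo(dtype).min

scores = torch.zeros(1, 1, 5, 5, dtype=dtype)      # dummy attention scores
probs = torch.softmax(scores + extended, dim=-1)   # masked columns get ~0 probability
print(torch.finfo(torch.float16).min, torch.finfo(torch.float32).min)  # the constant now tracks the dtype
print(probs[0, 0, 0])
```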
@@ -1110,7 +1110,7 @@ def _resize_qa_labels(self, num_labels):
def get_qa_logit_layer(self) -> nn.Module:
"""
- Returns the the linear layer that produces question answering logits.
+ Returns the linear layer that produces question answering logits.
Returns:
`nn.Module`: A torch module mapping the question answering prediction hidden states or `None` if LXMERT
@@ -1193,7 +1193,8 @@ def forward(
if "masked_lm_labels" in kwargs:
warnings.warn(
- "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
+ "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels`"
+ " instead.",
FutureWarning,
)
labels = kwargs.pop("masked_lm_labels")
@@ -1252,7 +1253,7 @@ def forward(
visual_prediction_scores = visual_prediction_scores_dict[key]
visual_loss = visual_loss_fct(
visual_prediction_scores.view(-1, output_dim),
- label.view(*label_shape),
+ label.view(label_shape),
)
if visual_loss.dim() > 1: # Regression Losses
visual_loss = visual_loss.mean(1)
@@ -1340,7 +1341,7 @@ def _resize_qa_labels(self, num_labels):
def get_qa_logit_layer(self) -> nn.Module:
"""
- Returns the the linear layer that produces question answering logits
+ Returns the linear layer that produces question answering logits
Returns:
`nn.Module`: A torch module mapping the question answering prediction hidden states. `None`: A NoneType
diff --git a/src/transformers/models/lxmert/tokenization_lxmert_fast.py b/src/transformers/models/lxmert/tokenization_lxmert_fast.py
index 9e88bc1581cb..8cfa20a9a26f 100644
--- a/src/transformers/models/lxmert/tokenization_lxmert_fast.py
+++ b/src/transformers/models/lxmert/tokenization_lxmert_fast.py
@@ -24,7 +24,9 @@
"unc-nlp/lxmert-base-uncased": "https://huggingface.co/unc-nlp/lxmert-base-uncased/resolve/main/vocab.txt",
},
"tokenizer_file": {
- "unc-nlp/lxmert-base-uncased": "https://huggingface.co/unc-nlp/lxmert-base-uncased/resolve/main/tokenizer.json",
+ "unc-nlp/lxmert-base-uncased": (
+ "https://huggingface.co/unc-nlp/lxmert-base-uncased/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/m2m_100/__init__.py b/src/transformers/models/m2m_100/__init__.py
index 81d664d0f79b..23b7e2a46cbe 100644
--- a/src/transformers/models/m2m_100/__init__.py
+++ b/src/transformers/models/m2m_100/__init__.py
@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
@@ -26,7 +26,12 @@
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_m2m_100"] = [
"M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
"M2M100ForConditionalGeneration",
@@ -39,7 +44,12 @@
from .configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config, M2M100OnnxConfig
from .tokenization_m2m_100 import M2M100Tokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_m2m_100 import (
M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST,
M2M100ForConditionalGeneration,
diff --git a/src/transformers/models/m2m_100/convert_m2m100_original_checkpoint_to_pytorch.py b/src/transformers/models/m2m_100/convert_m2m100_original_checkpoint_to_pytorch.py
index 74580bc181fe..97265fbdcf93 100644
--- a/src/transformers/models/m2m_100/convert_m2m100_original_checkpoint_to_pytorch.py
+++ b/src/transformers/models/m2m_100/convert_m2m100_original_checkpoint_to_pytorch.py
@@ -44,7 +44,7 @@ def make_linear_from_emb(emb):
def convert_fairseq_m2m100_checkpoint_from_disk(checkpoint_path):
m2m_100 = torch.load(checkpoint_path, map_location="cpu")
- args = m2m_100["args"]
+ args = m2m_100["args"] or m2m_100["cfg"]["model"]
state_dict = m2m_100["model"]
remove_ignore_keys_(state_dict)
vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
@@ -69,7 +69,7 @@ def convert_fairseq_m2m100_checkpoint_from_disk(checkpoint_path):
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
model = M2M100ForConditionalGeneration(config)
- model.model.load_state_dict(state_dict)
+ model.model.load_state_dict(state_dict, strict=False)
model.lm_head = make_linear_from_emb(model.model.shared)
return model
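`load_state_dict(..., strict=False)` above lets the conversion continue when a fairseq checkpoint's key set does not line up exactly with the Hugging Face module (the same hunk also falls back to `cfg["model"]` when `args` is absent). A tiny illustrative sketch, on a toy module rather than the real M2M-100 weights, of what non-strict loading returns:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)

# A checkpoint missing "bias" and carrying an extra, stale key.
checkpoint = {"weight": torch.zeros(2, 4), "stale_key": torch.zeros(1)}

result = model.load_state_dict(checkpoint, strict=False)
# strict=False reports the mismatches instead of raising, so they can be inspected or logged.
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # ['stale_key']
```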
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index 36539736bf84..3abe593bb129 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -79,7 +79,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -101,7 +101,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
@@ -131,9 +131,7 @@ def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Opt
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
- self.weights = nn.Parameter(emb_weights)
- self.weights.requires_grad = False
- self.weights.detach_()
+ self.register_buffer("weights", emb_weights)
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
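Storing the sinusoidal table with `register_buffer` (rather than an `nn.Parameter` that is then detached and frozen) keeps it out of `model.parameters()` while it still follows the module across devices and is saved in the state dict. A quick standalone comparison on a toy module:

```python
import torch
from torch import nn


class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-trainable positional table: a buffer, not a parameter.
        self.register_buffer("weights", torch.randn(10, 4))


m = WithBuffer()
print(list(m.parameters()))          # [] -> nothing for an optimizer to touch
print("weights" in m.state_dict())   # True -> still serialized and moved by .to(device)
```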
@@ -288,7 +286,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -304,7 +303,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -325,7 +325,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -565,7 +566,7 @@ def _set_gradient_checkpointing(self, module, value=False):
"""
M2M_100_GENERATION_EXAMPLE = r"""
- Translation example::
+ Translation example:
```python
>>> from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
@@ -645,11 +646,10 @@ def _set_gradient_checkpointing(self, module, value=False):
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
- ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
- shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids`
- you can choose to directly pass an embedded representation. This is useful if you want more control over
- how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup
- matrix.
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape
+ `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you
+ can choose to directly pass an embedded representation. This is useful if you want more control over how to
+ convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
@@ -793,7 +793,8 @@ def forward(
if head_mask is not None:
if head_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
@@ -950,8 +951,8 @@ def forward(
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
- all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor`
- of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
+ shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
control over how to convert `input_ids` indices into associated vectors than the model's internal
embedding lookup matrix.
@@ -994,7 +995,7 @@ def forward(
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -1025,7 +1026,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
@@ -1046,7 +1048,8 @@ def forward(
if use_cache:
logger.warning(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting"
+ " `use_cache=False`..."
)
use_cache = False
@@ -1235,9 +1238,9 @@ def forward(
class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
- r"encoder\.version",
- r"decoder\.version",
- r"lm_head\.weight",
+ r"encoder.version",
+ r"decoder.version",
+ r"lm_head.weight",
r"model.encoder.embed_positions.weights",
r"model.decoder.embed_positions.weights",
]
@@ -1299,22 +1302,7 @@ def forward(
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
-
- Example:
-
- ```python
- >>> from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
-
- >>> model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
- >>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
-
- >>> text_to_translate = "Life is like a box of chocolates"
- >>> model_inputs = tokenizer(text_to_translate, return_tensors="pt")
-
- >>> # translate to French
- >>> gen_tokens = model.generate(**model_inputs, forced_bos_token_id=tokenizer.get_lang_id("fr"))
- >>> print(tokenizer.batch_decode(gen_tokens, skip_special_tokens=True))
- ```"""
+ """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
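The translation flow that the trimmed docstring used to spell out (and that the module-level `M2M_100_GENERATION_EXAMPLE` earlier in this file still documents) hinges on `forced_bos_token_id` to pick the target language. A minimal sketch of that flow, with the checkpoint and language taken from the removed example (weights are downloaded on first use):

```python
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

model_inputs = tokenizer("Life is like a box of chocolates", return_tensors="pt")

# Force the decoder to start with the French language token so the model translates into French.
gen_tokens = model.generate(**model_inputs, forced_bos_token_id=tokenizer.get_lang_id("fr"))
print(tokenizer.batch_decode(gen_tokens, skip_special_tokens=True))
```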
diff --git a/src/transformers/models/m2m_100/tokenization_m2m_100.py b/src/transformers/models/m2m_100/tokenization_m2m_100.py
index f2e9c855bf90..b67b82fb7a58 100644
--- a/src/transformers/models/m2m_100/tokenization_m2m_100.py
+++ b/src/transformers/models/m2m_100/tokenization_m2m_100.py
@@ -14,7 +14,6 @@
"""Tokenization classes for M2M100."""
import json
import os
-from contextlib import contextmanager
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -116,10 +115,8 @@ class M2M100Tokenizer(PreTrainedTokenizer):
>>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="en", tgt_lang="ro")
>>> src_text = " UN Chief Says There Is No Military Solution in Syria"
>>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
- >>> model_inputs = tokenizer(src_text, return_tensors="pt")
- >>> with tokenizer.as_target_tokenizer():
- ... labels = tokenizer(tgt_text, return_tensors="pt").input_ids
- >>> model(**model_inputs, labels=labels) # should work
+ >>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
+ >>> model(**model_inputs) # should work
```"""
vocab_files_names = VOCAB_FILES_NAMES
@@ -346,16 +343,12 @@ def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lan
inputs["forced_bos_token_id"] = tgt_lang_id
return inputs
- @contextmanager
- def as_target_tokenizer(self):
- """
- Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
- sequence-to-sequence models that need a slightly different processing for the labels.
- """
- self.set_tgt_lang_special_tokens(self.tgt_lang)
- yield
+ def _switch_to_input_mode(self):
self.set_src_lang_special_tokens(self.src_lang)
+ def _switch_to_target_mode(self):
+ self.set_tgt_lang_special_tokens(self.tgt_lang)
+
def set_src_lang_special_tokens(self, src_lang: str) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
lang_token = self.get_lang_token(src_lang)
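With `as_target_tokenizer` replaced by the `_switch_to_input_mode`/`_switch_to_target_mode` hooks, callers no longer wrap target encoding in a context manager: passing `text_target=` lets the tokenizer flip to the target-language special tokens internally and hand back `labels` alongside the source encoding, as the updated docstring shows. A short usage sketch mirroring that docstring (checkpoint and texts as in the example above):

```python
from transformers import M2M100Tokenizer

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="en", tgt_lang="ro")

src_text = "UN Chief Says There Is No Military Solution in Syria"
tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"

# Source is encoded with source-language special tokens, the target with
# target-language ones; the target ids come back under "labels".
model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
print(model_inputs.keys())  # input_ids, attention_mask, labels
```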
diff --git a/src/transformers/models/marian/__init__.py b/src/transformers/models/marian/__init__.py
index 5971d2d5743b..eaaaf290821b 100644
--- a/src/transformers/models/marian/__init__.py
+++ b/src/transformers/models/marian/__init__.py
@@ -18,6 +18,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
@@ -31,10 +32,20 @@
"configuration_marian": ["MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "MarianConfig", "MarianOnnxConfig"],
}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_marian"] = ["MarianTokenizer"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_marian"] = [
"MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST",
"MarianForCausalLM",
@@ -43,19 +54,39 @@
"MarianPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_marian"] = ["FlaxMarianModel", "FlaxMarianMTModel", "FlaxMarianPreTrainedModel"]
if TYPE_CHECKING:
from .configuration_marian import MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP, MarianConfig, MarianOnnxConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_marian import MarianTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_marian import (
MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST,
MarianForCausalLM,
@@ -64,10 +95,20 @@
MarianPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel
else:
diff --git a/src/transformers/models/marian/configuration_marian.py b/src/transformers/models/marian/configuration_marian.py
index 835b317f9d92..f662d388448b 100644
--- a/src/transformers/models/marian/configuration_marian.py
+++ b/src/transformers/models/marian/configuration_marian.py
@@ -327,8 +327,9 @@ def _generate_dummy_inputs_for_causal_lm(
self._config.hidden_size // num_encoder_attention_heads,
)
+ mask_dtype = common_inputs["attention_mask"].dtype
common_inputs["attention_mask"] = torch.cat(
- [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+ [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
)
common_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
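The dummy-input fix above makes the padding extension inherit the dtype of the existing attention mask (integer-typed when produced by a tokenizer) instead of `torch.ones`'s default float32. A minimal standalone sketch of the same idea:

```python
import torch

attention_mask = torch.ones(2, 5, dtype=torch.int64)  # tokenizer-style integer mask
past_key_values_length = 3

# Extend the mask over the cached positions, reusing the original dtype so the
# concatenated mask stays homogeneous instead of mixing int64 and float32.
mask_dtype = attention_mask.dtype
extension = torch.ones(attention_mask.shape[0], past_key_values_length, dtype=mask_dtype)
extended_mask = torch.cat([attention_mask, extension], dim=1)
print(extended_mask.dtype, tuple(extended_mask.shape))  # torch.int64 (2, 8)
```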
diff --git a/src/transformers/models/marian/convert_marian_to_pytorch.py b/src/transformers/models/marian/convert_marian_to_pytorch.py
index bd8490cb2d62..1fb5a34f064f 100644
--- a/src/transformers/models/marian/convert_marian_to_pytorch.py
+++ b/src/transformers/models/marian/convert_marian_to_pytorch.py
@@ -140,17 +140,21 @@ def find_model_file(dest_dir): # this one better
"opus-mt-NORTH_EU-NORTH_EU": "de+nl+fy+af+da+fo+is+no+nb+nn+sv-de+nl+fy+af+da+fo+is+no+nb+nn+sv",
"opus-mt-de-ZH": "de-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
"opus-mt-en_el_es_fi-en_el_es_fi": "en+el+es+fi-en+el+es+fi",
- "opus-mt-en-ROMANCE": "en-fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
- "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
- "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la",
+ "opus-mt-en-ROMANCE": (
+ "en-fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
+ "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
+ "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la"
+ ),
"opus-mt-en-CELTIC": "en-ga+cy+br+gd+kw+gv",
"opus-mt-es-NORWAY": "es-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
"opus-mt-fi_nb_no_nn_ru_sv_en-SAMI": "fi+nb+no+nn+ru+sv+en-se+sma+smj+smn+sms",
"opus-mt-fi-ZH": "fi-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
"opus-mt-fi-NORWAY": "fi-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
- "opus-mt-ROMANCE-en": "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
- "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
- "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la-en",
+ "opus-mt-ROMANCE-en": (
+ "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
+ "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
+ "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la-en"
+ ),
"opus-mt-CELTIC-en": "ga+cy+br+gd+kw+gv-en",
"opus-mt-sv-ZH": "sv-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
"opus-mt-sv-NORWAY": "sv-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py
index 8fea39e19aeb..da2e4a1fe5b5 100644
--- a/src/transformers/models/marian/modeling_flax_marian.py
+++ b/src/transformers/models/marian/modeling_flax_marian.py
@@ -551,7 +551,7 @@ def setup(self) -> None:
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
- self.config.encoder_ffn_dim,
+ self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
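The one-line Flax fix above only shows up when `encoder_ffn_dim != decoder_ffn_dim`: the decoder layer's first projection must expand to the decoder's own feed-forward width before `fc2` maps back to `d_model`. A stripped-down sketch of that wiring (a hypothetical `DecoderFFN` module, not the actual `FlaxMarianDecoderLayer`):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class DecoderFFN(nn.Module):
    decoder_ffn_dim: int  # hidden width of the decoder's feed-forward block
    d_model: int          # model (residual stream) width

    @nn.compact
    def __call__(self, hidden_states):
        hidden_states = nn.Dense(self.decoder_ffn_dim)(hidden_states)  # fc1: decoder_ffn_dim, not encoder_ffn_dim
        hidden_states = nn.relu(hidden_states)
        hidden_states = nn.Dense(self.d_model)(hidden_states)          # fc2: back to d_model
        return hidden_states


ffn = DecoderFFN(decoder_ffn_dim=2048, d_model=512)
params = ffn.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 512)))
print(jax.tree_util.tree_map(jnp.shape, params))  # Dense_0 kernel (512, 2048), Dense_1 kernel (2048, 512)
```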
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index 65a471d6417c..26dc6b12dc9f 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -81,7 +81,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -103,7 +103,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class MarianSinusoidalPositionalEmbedding(nn.Embedding):
@@ -233,7 +233,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -249,7 +250,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -270,7 +272,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -853,11 +856,13 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
@@ -993,9 +998,10 @@ def forward(
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
- assert attn_mask.size()[0] == (
- len(self.layers)
- ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ assert attn_mask.size()[0] == (len(self.layers)), (
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
@@ -1268,9 +1274,9 @@ class MarianMTModel(MarianPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
- r"encoder\.version",
- r"decoder\.version",
- r"lm_head\.weight",
+ r"encoder.version",
+ r"decoder.version",
+ r"lm_head.weight",
r"embed_positions",
]
diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py
index 04a24ac9f9f1..0c2a0334dbae 100644
--- a/src/transformers/models/marian/modeling_tf_marian.py
+++ b/src/transformers/models/marian/modeling_tf_marian.py
@@ -88,7 +88,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
"""
Make causal mask used for bi-directional self-attention.
"""
- bsz, tgt_len = input_ids_shape
+ bsz = input_ids_shape[0]
+ tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
@@ -101,7 +102,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
-def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -162,12 +163,14 @@ def _init_weight(n_pos: int, dim: int):
tf.stop_gradient(table)
return table
- def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
+ def call(
+ self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: Optional[tf.Tensor] = None
+ ):
"""Input is expected to be of size [bsz x seqlen]."""
- bsz, seq_len = input_shape[:2]
-
- positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
- return tf.gather(self.weight, positions)
+ if position_ids is None:
+ seq_len = input_shape[1]
+ position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
+ return tf.gather(self.weight, position_ids)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Marian
@@ -267,7 +270,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
- message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
)
if attention_mask is not None:
@@ -277,7 +283,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
- message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
@@ -293,7 +302,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
@@ -310,7 +322,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
- message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
)
attn_output = tf.transpose(
@@ -616,6 +631,9 @@ def serving(self, inputs):
`past_key_values`).
decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
+ decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.max_position_embeddings - 1]`.
head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
@@ -784,7 +802,10 @@ def call(
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
- message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
+ message=(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(head_mask)[0]}."
+ ),
)
# encoder layers
@@ -855,6 +876,7 @@ def call(
input_ids=None,
inputs_embeds=None,
attention_mask=None,
+ position_ids=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
@@ -883,6 +905,9 @@ def call(
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
+ position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.max_position_embeddings - 1]`.
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
@@ -912,11 +937,11 @@ def call(
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
- all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`tf.Tensor` of
- shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
- `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
- control over how to convert `input_ids` indices into associated vectors than the model's internal
- embedding lookup matrix.
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`tf.Tensor` of shape
+ `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids`
+ you can choose to directly pass an embedded representation. This is useful if you want more control
+ over how to convert `input_ids` indices into associated vectors than the model's internal embedding
+ lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value
@@ -945,7 +970,10 @@ def call(
past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
# embed positions
- positions = self.embed_positions(input_shape, past_key_values_length)
+ if position_ids is None:
+ positions = self.embed_positions(input_shape, past_key_values_length)
+ else:
+ positions = self.embed_positions(input_shape, position_ids=position_ids)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
@@ -983,7 +1011,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
- message=f"The {attn_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+ message=(
+ f"The {attn_name} should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(attn_mask)[0]}."
+ ),
)
for idx, decoder_layer in enumerate(self.layers):
@@ -1073,6 +1104,7 @@ def call(
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
+ decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@@ -1120,6 +1152,7 @@ def call(
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
+ position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
@@ -1178,6 +1211,7 @@ def call(
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
+ decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@@ -1197,6 +1231,7 @@ def call(
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
+ decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@@ -1248,7 +1283,7 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model = TFMarianMainLayer(config, name="model")
self.use_cache = config.use_cache
- # final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
+ # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self.final_logits_bias = self.add_weight(
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
)
@@ -1281,6 +1316,7 @@ def call(
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
+ decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@@ -1323,6 +1359,7 @@ def call(
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
+ decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@@ -1380,6 +1417,7 @@ def prepare_inputs_for_generation(
decoder_input_ids,
past=None,
attention_mask=None,
+ decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@@ -1387,16 +1425,26 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs
):
+
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
+ if decoder_attention_mask is not None: # xla
+ decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
+ elif past is not None: # no xla + past
+ decoder_position_ids = past[0][0].shape[2]
+ else: # no xla + no past
+ decoder_position_ids = tf.range(decoder_input_ids.shape[1])
+
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
+ "decoder_attention_mask": decoder_attention_mask,
+ "decoder_position_ids": decoder_position_ids,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
diff --git a/src/transformers/models/marian/tokenization_marian.py b/src/transformers/models/marian/tokenization_marian.py
index 3579d5dffa18..66eb5a44c5bf 100644
--- a/src/transformers/models/marian/tokenization_marian.py
+++ b/src/transformers/models/marian/tokenization_marian.py
@@ -15,7 +15,6 @@
import os
import re
import warnings
-from contextlib import contextmanager
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -47,7 +46,9 @@
"Helsinki-NLP/opus-mt-en-de": "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/vocab.json"
},
"tokenizer_config_file": {
- "Helsinki-NLP/opus-mt-en-de": "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/tokenizer_config.json"
+ "Helsinki-NLP/opus-mt-en-de": (
+ "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/tokenizer_config.json"
+ )
},
}
@@ -110,10 +111,7 @@ class MarianTokenizer(PreTrainedTokenizer):
>>> tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
>>> src_texts = ["I am a small frog.", "Tom asked his teacher for advice."]
>>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."] # optional
- >>> inputs = tokenizer(src_texts, return_tensors="pt", padding=True)
- >>> with tokenizer.as_target_tokenizer():
- ... labels = tokenizer(tgt_texts, return_tensors="pt", padding=True)
- >>> inputs["labels"] = labels["input_ids"]
+ >>> inputs = tokenizer(src_texts, text_target=tgt_texts, return_tensors="pt", padding=True)
# keys [input_ids, attention_mask, labels].
>>> outputs = model(**inputs) # should work
@@ -279,18 +277,14 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> Lis
# We don't expect to process pairs, but leave the pair logic for API consistency
return token_ids_0 + token_ids_1 + [self.eos_token_id]
- @contextmanager
- def as_target_tokenizer(self):
- """
- Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
- sequence-to-sequence models that need a slightly different processing for the labels.
- """
+ def _switch_to_input_mode(self):
+ self.current_spm = self.spm_source
+ self.current_encoder = self.encoder
+
+ def _switch_to_target_mode(self):
self.current_spm = self.spm_target
if self.separate_vocabs:
self.current_encoder = self.target_encoder
- yield
- self.current_spm = self.spm_source
- self.current_encoder = self.encoder
@property
def vocab_size(self) -> int:
diff --git a/src/transformers/models/maskformer/__init__.py b/src/transformers/models/maskformer/__init__.py
index 2f15ed34f0c2..4234f76dc565 100644
--- a/src/transformers/models/maskformer/__init__.py
+++ b/src/transformers/models/maskformer/__init__.py
@@ -17,18 +17,26 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
-_import_structure = {
- "configuration_maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"],
-}
+_import_structure = {"configuration_maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_maskformer"] = ["MaskFormerFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_maskformer"] = [
"MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"MaskFormerForInstanceSegmentation",
@@ -39,9 +47,19 @@
if TYPE_CHECKING:
from .configuration_maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_maskformer import MaskFormerFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_maskformer import (
MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
MaskFormerForInstanceSegmentation,
diff --git a/src/transformers/models/maskformer/configuration_maskformer.py b/src/transformers/models/maskformer/configuration_maskformer.py
index 50ad6880adb2..ab68de3f0453 100644
--- a/src/transformers/models/maskformer/configuration_maskformer.py
+++ b/src/transformers/models/maskformer/configuration_maskformer.py
@@ -24,7 +24,9 @@
MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "facebook/maskformer-swin-base-ade": "https://huggingface.co/facebook/maskformer-swin-base-ade/blob/main/config.json"
+ "facebook/maskformer-swin-base-ade": (
+ "https://huggingface.co/facebook/maskformer-swin-base-ade/blob/main/config.json"
+ )
# See all MaskFormer models at https://huggingface.co/models?filter=maskformer
}
@@ -130,7 +132,8 @@ def __init__(
backbone_model_type = backbone_config.pop("model_type")
if backbone_model_type not in self.backbones_supported:
raise ValueError(
- f"Backbone {backbone_model_type} not supported, please use one of {','.join(self.backbones_supported)}"
+ f"Backbone {backbone_model_type} not supported, please use one of"
+ f" {','.join(self.backbones_supported)}"
)
backbone_config = AutoConfig.for_model(backbone_model_type, **backbone_config)
@@ -141,7 +144,8 @@ def __init__(
decoder_type = decoder_config.pop("model_type")
if decoder_type not in self.decoders_supported:
raise ValueError(
- f"Transformer Decoder {decoder_type} not supported, please use one of {','.join(self.decoders_supported)}"
+ f"Transformer Decoder {decoder_type} not supported, please use one of"
+ f" {','.join(self.decoders_supported)}"
)
decoder_config = AutoConfig.for_model(decoder_type, **decoder_config)
diff --git a/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py
index 045d2bc0f515..c08591e044db 100644
--- a/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py
@@ -188,7 +188,7 @@ def __init__(self, original_model: nn.Module, config: MaskFormerConfig):
self.config = config
def pop_all(self, renamed_keys: List[Tuple[str, str]], dst_state_dict: StateDict, src_state_dict: StateDict):
- for (src_key, dst_key) in renamed_keys:
+ for src_key, dst_key in renamed_keys:
dst_state_dict[dst_key] = src_state_dict.pop(src_key)
def replace_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: MaskFormerConfig):
@@ -643,12 +643,18 @@ def get_name(checkpoint_file: Path):
parser.add_argument(
"--checkpoints_dir",
type=Path,
- help="A directory containing the model's checkpoints. The directory has to have the following structure: //.pkl",
+ help=(
+ "A directory containing the model's checkpoints. The directory has to have the following structure:"
+ " //.pkl"
+ ),
)
parser.add_argument(
"--configs_dir",
type=Path,
- help="A directory containing the model's configs, see detectron2 doc. The directory has to have the following structure: //.yaml",
+ help=(
+ "A directory containing the model's configs, see detectron2 doc. The directory has to have the following"
+ " structure: //.yaml"
+ ),
)
parser.add_argument(
"--pytorch_dump_folder_path",
@@ -660,7 +666,10 @@ def get_name(checkpoint_file: Path):
"--maskformer_dir",
required=True,
type=Path,
- help="A path to MaskFormer's original implementation directory. You can download from here: https://github.com/facebookresearch/MaskFormer",
+ help=(
+ "A path to MaskFormer's original implementation directory. You can download from here:"
+ " https://github.com/facebookresearch/MaskFormer"
+ ),
)
args = parser.parse_args()
diff --git a/src/transformers/models/maskformer/feature_extraction_maskformer.py b/src/transformers/models/maskformer/feature_extraction_maskformer.py
index 5e466f2ddb07..3a5fd49d80fa 100644
--- a/src/transformers/models/maskformer/feature_extraction_maskformer.py
+++ b/src/transformers/models/maskformer/feature_extraction_maskformer.py
@@ -253,8 +253,9 @@ def __call__(
if not valid_segmentation_maps:
raise ValueError(
- "Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example),"
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
+ "Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single"
+ " example),`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of"
+ " examples)."
)
is_batched = bool(
@@ -591,7 +592,7 @@ def post_process_panoptic_segmentation(
# mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH]
# now, we need to iterate over the batch size to correctly process the segmentation we got from the queries using our thresholds. Even if the original predicted masks have the same shape across the batch, they won't after thresholding so batch-wise operations are impossible
results: List[Dict[str, Tensor]] = []
- for (mask_probs, pred_scores, pred_labels) in zip(mask_probs, pred_scores, pred_labels):
+ for mask_probs, pred_scores, pred_labels in zip(mask_probs, pred_scores, pred_labels):
mask_probs, pred_scores, pred_labels = self.remove_low_and_no_objects(
mask_probs, pred_scores, pred_labels, object_mask_threshold, num_labels
)
diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py
index bfb020895ce9..1266dbfdad84 100644
--- a/src/transformers/models/maskformer/modeling_maskformer.py
+++ b/src/transformers/models/maskformer/modeling_maskformer.py
@@ -471,13 +471,6 @@ def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float =
return loss / height_and_width
-# Copied from transformers.models.vit.modeling_vit.to_2tuple
-def to_2tuple(x):
- if isinstance(x, collections.abc.Iterable):
- return x
- return (x, x)
-
-
# Copied from transformers.models.swin.modeling_swin.window_partition
def window_partition(input_feature, window_size):
"""
@@ -496,7 +489,7 @@ def window_reverse(windows, window_size, height, width):
"""
Merges windows to produce higher resolution features.
"""
- batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
+ batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
return windows
@@ -506,15 +499,21 @@ def window_reverse(windows, window_size, height, width):
def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
"""
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = input.new_empty(shape).bernoulli_(keep_prob)
- if keep_prob > 0.0 and scale_by_keep:
- random_tensor.div_(keep_prob)
- return input * random_tensor
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
class MaskFormerSwinEmbeddings(nn.Module):
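The reworked `drop_path` above implements stochastic depth: one Bernoulli draw per sample decides whether its residual branch is dropped, and kept samples are rescaled by `1 / keep_prob` so the expectation is unchanged. A standalone sketch of that behaviour (same logic, outside the model):

```python
import torch


def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    # Same scheme as above: one draw per sample, broadcast over the remaining dims.
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    return x.div(keep_prob) * random_tensor


torch.manual_seed(0)
x = torch.ones(10000, 3)
out = drop_path(x, drop_prob=0.2, training=True)
print((out.sum(dim=1) == 0).float().mean())  # ~0.2 of the samples dropped entirely
print(out.mean())                            # ~1.0: rescaling preserves the expectation
```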
@@ -525,12 +524,7 @@ class MaskFormerSwinEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
- self.patch_embeddings = MaskFormerSwinPatchEmbeddings(
- image_size=config.image_size,
- patch_size=config.patch_size,
- num_channels=config.num_channels,
- embed_dim=config.embed_dim,
- )
+ self.patch_embeddings = MaskFormerSwinPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.patch_grid = self.patch_embeddings.grid_size
@@ -559,17 +553,21 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
Image to Patch Embedding, including padding.
"""
- def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
+ def __init__(self, config):
super().__init__()
- image_size = to_2tuple(image_size)
- patch_size = to_2tuple(patch_size)
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.embed_dim
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
+ self.num_channels = num_channels
self.num_patches = num_patches
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
- self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def maybe_pad(self, pixel_values, height, width):
if width % self.patch_size[1] != 0:
@@ -581,7 +579,11 @@ def maybe_pad(self, pixel_values, height, width):
return pixel_values
def forward(self, pixel_values):
- _, _, height, width = pixel_values.shape
+ _, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
embeddings = self.projection(pixel_values)
@@ -649,13 +651,15 @@ def forward(self, input_feature, input_dimensions):
class MaskFormerSwinDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
- def __init__(self, drop_prob=None, scale_by_keep=True):
- super(MaskFormerSwinDropPath, self).__init__()
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
self.drop_prob = drop_prob
- self.scale_by_keep = scale_by_keep
- def forward(self, input):
- return drop_path(input, self.drop_prob, self.training, self.scale_by_keep)
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin
@@ -664,13 +668,16 @@ def __init__(self, config, dim, num_heads):
super().__init__()
if dim % num_heads != 0:
raise ValueError(
- f"The hidden size ({dim}) is not a multiple of the number of attention " f"heads ({num_heads})"
+ f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
)
self.num_attention_heads = num_heads
self.attention_head_size = int(dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
- self.window_size = to_2tuple(config.window_size)
+ window_size = config.window_size
+ self.window_size = (
+ window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
+ )
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
@@ -697,7 +704,7 @@ def __init__(self, config, dim, num_heads):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
@@ -750,7 +757,7 @@ def forward(
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -1194,7 +1201,8 @@ def __init__(
self.head_dim = embed_dim // num_heads
if self.head_dim * num_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {num_heads})."
)
self.scaling = self.head_dim**-0.5
@@ -1258,7 +1266,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -1287,7 +1296,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -1902,7 +1912,7 @@ def forward(
def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
"""
- Computes the average number of target masks accross the batch, for normalization purposes.
+ Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
@@ -1955,7 +1965,7 @@ def outputs_shapes(self) -> List[int]:
return [layer.dim for layer in self.model.encoder.layers]
-class MaskFormerFPNConvLayer(nn.Sequential):
+class MaskFormerFPNConvLayer(nn.Module):
def __init__(self, in_features: int, out_features: int, kernel_size: int = 3, padding: int = 1):
"""
A basic module that executes conv - norm - in sequence used in MaskFormer.
@@ -1966,11 +1976,26 @@ def __init__(self, in_features: int, out_features: int, kernel_size: int = 3, pa
out_features (`int`):
The number of outputs features (channels).
"""
- super().__init__(
+ super().__init__()
+ self.layers = [
nn.Conv2d(in_features, out_features, kernel_size=kernel_size, padding=padding, bias=False),
nn.GroupNorm(32, out_features),
nn.ReLU(inplace=True),
- )
+ ]
+ for i, layer in enumerate(self.layers):
+ # Provide backwards compatibility from when the class inherited from nn.Sequential
+ # In nn.Sequential subclasses, the name given to the layer is its index in the sequence.
+ # In nn.Module subclasses they derived from the instance attribute they are assigned to e.g.
+ # self.my_layer_name = Layer()
+ # We can't give instance attributes integer names i.e. self.0 is not permitted and so need to register
+ # explicitly
+ self.add_module(str(i), layer)
+
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = input
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
class MaskFormerFPNLayer(nn.Module):
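The `nn.Sequential` → `nn.Module` conversions above register their children with `add_module(str(i), ...)` precisely so parameter names keep their old integer prefixes (`0.weight`, `1.weight`, ...) and existing checkpoints still load. A small sketch showing that the two layouts produce identical state-dict keys (toy conv/norm layers, not the MaskFormer ones):

```python
import torch
from torch import nn

sequential = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1, bias=False), nn.GroupNorm(4, 8), nn.ReLU())


class Unrolled(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Conv2d(3, 8, 3, padding=1, bias=False), nn.GroupNorm(4, 8), nn.ReLU()]
        # Register under string indices so checkpoints saved from the nn.Sequential
        # version ("0.weight", "1.weight", ...) keep matching.
        for i, layer in enumerate(self.layers):
            self.add_module(str(i), layer)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


unrolled = Unrolled()
print(list(sequential.state_dict().keys()) == list(unrolled.state_dict().keys()))  # True
unrolled.load_state_dict(sequential.state_dict())  # loads without missing/unexpected keys
```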
@@ -2098,7 +2123,22 @@ def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
return pos
-class MaskformerMLPPredictionHead(nn.Sequential):
+class PredictionBlock(nn.Module):
+ def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None:
+ super().__init__()
+ self.layers = [nn.Linear(in_dim, out_dim), activation]
+ # Maintain submodule indexing as if part of a Sequential block
+ for i, layer in enumerate(self.layers):
+ self.add_module(str(i), layer)
+
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = input
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
+
+
+class MaskformerMLPPredictionHead(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3):
"""
A classic Multi Layer Perceptron (MLP).
@@ -2113,18 +2153,28 @@ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers:
num_layers (int, *optional*, defaults to 3):
The number of layers.
"""
+ super().__init__()
in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
out_dims = [hidden_dim] * (num_layers - 1) + [output_dim]
- layers = []
+ self.layers = []
for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
-
- layer = nn.Sequential(
- nn.Linear(in_dim, out_dim), nn.ReLU(inplace=True) if i < num_layers - 1 else nn.Identity()
- )
- layers.append(layer)
-
- super().__init__(*layers)
+ activation = nn.ReLU() if i < num_layers - 1 else nn.Identity()
+ layer = PredictionBlock(in_dim, out_dim, activation=activation)
+ self.layers.append(layer)
+ # Provide backwards compatibility from when the class inherited from nn.Sequential
+ # In nn.Sequential subclasses, the name given to the layer is its index in the sequence.
+ # In nn.Module subclasses the name is derived from the instance attribute it is assigned to, e.g.
+ # self.my_layer_name = Layer()
+ # We can't give instance attributes integer names, i.e. self.0 is not permitted, so we need to register
+ # them explicitly
+ self.add_module(str(i), layer)
+
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = input
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
class MaskFormerPixelLevelModule(nn.Module):
@@ -2250,20 +2300,21 @@ def _init_weights(self, module: nn.Module):
nn.init.constant_(module.input_projection.bias, 0)
# FPN
elif isinstance(module, MaskFormerFPNModel):
- nn.init.xavier_uniform_(module.stem[0].weight, gain=xavier_std)
+ nn.init.xavier_uniform_(module.stem.get_submodule("0").weight, gain=xavier_std)
elif isinstance(module, MaskFormerFPNLayer):
nn.init.xavier_uniform_(module.proj[0].weight, gain=xavier_std)
elif isinstance(module, MaskFormerFPNConvLayer):
- nn.init.xavier_uniform_(module[0].weight, gain=xavier_std)
+ nn.init.xavier_uniform_(module.get_submodule("0").weight, gain=xavier_std)
# The MLP head
elif isinstance(module, MaskformerMLPPredictionHead):
# I was not able to find the correct initializer in the original implementation
# we'll use xavier
- for layer in module:
- nn.init.xavier_uniform_(layer[0].weight, gain=xavier_std)
- nn.init.constant_(layer[0].bias, 0)
+ for submodule in module.modules():
+ if isinstance(submodule, nn.Linear):
+ nn.init.xavier_uniform_(submodule.weight, gain=xavier_std)
+ nn.init.constant_(submodule.bias, 0)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
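The MaskFormer changes above convert `nn.Sequential` subclasses into `nn.Module` ones while keeping the old submodule names, so checkpoints saved before the change still load. Below is a minimal, self-contained sketch of that pattern; `OldBlock` and `NewBlock` are illustrative names, not classes from this diff.

```python
import torch
from torch import nn


class OldBlock(nn.Sequential):
    """Old-style block: children are implicitly named "0", "1", ..."""

    def __init__(self):
        super().__init__(nn.Linear(4, 4), nn.ReLU())


class NewBlock(nn.Module):
    """nn.Module version that keeps the same child names via add_module."""

    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4), nn.ReLU()]
        for i, layer in enumerate(self.layers):
            # self.0 is not a legal attribute name, so register the child explicitly
            self.add_module(str(i), layer)

    def forward(self, hidden_state):
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


old, new = OldBlock(), NewBlock()
assert old.state_dict().keys() == new.state_dict().keys()  # both expose "0.weight", "0.bias"
new.load_state_dict(old.state_dict())  # checkpoints from the nn.Sequential era still load
print(new(torch.zeros(2, 4)).shape)
```

Because both versions expose children named `"0"`, `"1"`, ..., the `state_dict` keys stay identical, which is also why `_init_weights` above can now look layers up with `get_submodule("0")`.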
diff --git a/src/transformers/models/mbart/__init__.py b/src/transformers/models/mbart/__init__.py
index 294eb15f0366..ef967c2482a1 100644
--- a/src/transformers/models/mbart/__init__.py
+++ b/src/transformers/models/mbart/__init__.py
@@ -18,6 +18,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
@@ -27,17 +28,30 @@
)
-_import_structure = {
- "configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig", "MBartOnnxConfig"],
-}
+_import_structure = {"configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig", "MBartOnnxConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_mbart"] = ["MBartTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_mbart"] = [
"MBART_PRETRAINED_MODEL_ARCHIVE_LIST",
"MBartForCausalLM",
@@ -48,14 +62,24 @@
"MBartPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_mbart"] = [
"TFMBartForConditionalGeneration",
"TFMBartModel",
"TFMBartPreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_mbart"] = [
"FlaxMBartForConditionalGeneration",
"FlaxMBartForQuestionAnswering",
@@ -68,13 +92,28 @@
if TYPE_CHECKING:
from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig, MBartOnnxConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_mbart import MBartTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_mbart_fast import MBartTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_mbart import (
MBART_PRETRAINED_MODEL_ARCHIVE_LIST,
MBartForCausalLM,
@@ -85,10 +124,20 @@
MBartPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_mbart import (
FlaxMBartForConditionalGeneration,
FlaxMBartForQuestionAnswering,
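The mbart `__init__.py` rewrite above swaps plain `if is_*_available():` guards for the `try`/`except OptionalDependencyNotAvailable` pattern used for lazy imports across the library. A condensed sketch of the pattern follows; `configuration_foo`, `FooConfig` and `FooModel` are placeholder names, not real modules.

```python
from transformers.utils import OptionalDependencyNotAvailable, is_torch_available

# Modules that are always importable.
_import_structure = {"configuration_foo": ["FooConfig"]}

try:
    # Raising instead of branching keeps the structure uniform: the top-level
    # transformers __init__ uses the except-branch to register dummy placeholder
    # objects, while submodule inits like this one simply skip the entry.
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_foo"] = ["FooModel"]
```

The `else` branch only extends `_import_structure`; the actual import is deferred to `_LazyModule`, so importing the package stays cheap even when heavy backends are installed.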
diff --git a/src/transformers/models/mbart/configuration_mbart.py b/src/transformers/models/mbart/configuration_mbart.py
index e4da61442d16..af67cf858db1 100644
--- a/src/transformers/models/mbart/configuration_mbart.py
+++ b/src/transformers/models/mbart/configuration_mbart.py
@@ -322,8 +322,9 @@ def _generate_dummy_inputs_for_causal_lm(
self._config.hidden_size // num_encoder_attention_heads,
)
+ mask_dtype = common_inputs["attention_mask"].dtype
common_inputs["attention_mask"] = torch.cat(
- [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+ [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
)
common_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
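The one-line change to the ONNX dummy-input generation above forwards the existing mask's dtype when padding it with ones for the past keys. The standalone snippet below (not the exporter itself) illustrates why: `torch.ones` defaults to `float32`, while tokenizer masks are `int64`.

```python
import torch

attention_mask = torch.ones(2, 5, dtype=torch.int64)  # what tokenizers produce
past_key_values_length = 3

naive = torch.cat([attention_mask, torch.ones(2, past_key_values_length)], dim=1)
fixed = torch.cat(
    [attention_mask, torch.ones(2, past_key_values_length, dtype=attention_mask.dtype)], dim=1
)

# Recent PyTorch silently promotes the concatenation to float32 (older versions
# error out); either way the mask would no longer match the original dtype.
print(naive.dtype)  # torch.float32
print(fixed.dtype)  # torch.int64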
diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py
index 141d2b10415e..7cb52033b78a 100644
--- a/src/transformers/models/mbart/modeling_flax_mbart.py
+++ b/src/transformers/models/mbart/modeling_flax_mbart.py
@@ -550,7 +550,7 @@ def setup(self) -> None:
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
- self.config.encoder_ffn_dim,
+ self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 78d094922ba1..16ea95bc0aed 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -97,7 +97,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -119,7 +119,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
@@ -236,7 +236,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -252,7 +253,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -273,7 +275,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -808,7 +811,8 @@ def forward(
if head_mask is not None:
if head_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
@@ -905,11 +909,13 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
@@ -1048,7 +1054,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -1258,9 +1265,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
- r"encoder\.version",
- r"decoder\.version",
- r"lm_head\.weight",
+ r"encoder.version",
+ r"decoder.version",
+ r"lm_head.weight",
]
def __init__(self, config: MBartConfig):
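Among the modeling_mbart.py edits above, `_make_causal_mask` now fills masked positions with `torch.finfo(dtype).min` instead of `float("-inf")`. A small sketch, independent of the model code, of the half-precision failure mode this avoids:

```python
import torch

dtype = torch.float16
scores = torch.zeros(1, 4, dtype=dtype)  # pretend every key is masked for this query

inf_masked = scores + float("-inf")
finite_masked = scores + torch.finfo(dtype).min

print(torch.softmax(inf_masked, dim=-1))     # all NaN: softmax hits (-inf) - (-inf)
print(torch.softmax(finite_masked, dim=-1))  # finite, uniform weights, no NaNs
```

One way a fully masked row can arise in practice is when the query position is itself padding; with `-inf` that row turns into NaNs that then propagate through the rest of the forward pass.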
diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py
index b31ac1bd635d..5cb39d918d5f 100644
--- a/src/transformers/models/mbart/modeling_tf_mbart.py
+++ b/src/transformers/models/mbart/modeling_tf_mbart.py
@@ -86,7 +86,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
"""
Make causal mask used for bi-directional self-attention.
"""
- bsz, tgt_len = input_ids_shape
+ bsz = input_ids_shape[0]
+ tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
@@ -99,7 +100,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
-def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -124,12 +125,19 @@ def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs)
- def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
+ def call(
+ self,
+ input_shape: Optional[tf.TensorShape] = None,
+ past_key_values_length: int = 0,
+ position_ids: Optional[tf.Tensor] = None,
+ ):
"""Input is expected to be of size [bsz x seqlen]."""
- bsz, seq_len = input_shape[:2]
+ if position_ids is None:
+ seq_len = input_shape[1]
+ position_ids = tf.range(seq_len, delta=1, name="range")
+ position_ids += past_key_values_length
- positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
- return super().call(positions + self.offset)
+ return super().call(position_ids + self.offset)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->MBart
@@ -229,7 +237,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
- message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
)
if attention_mask is not None:
@@ -239,7 +250,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
- message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
@@ -255,7 +269,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
@@ -272,7 +289,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
- message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
)
attn_output = tf.transpose(
@@ -557,6 +577,9 @@ def serving(self, inputs):
for denoising pre-training following the paper.
decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
+ decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.max_position_embeddings - 1]`.
head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
@@ -763,7 +786,10 @@ def call(
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
- message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
+ message=(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(head_mask)[0]}."
+ ),
)
# encoder layers
@@ -838,6 +864,7 @@ def call(
input_ids: TFModelInputType = None,
inputs_embeds: Optional[tf.Tensor] = None,
attention_mask: Optional[tf.Tensor] = None,
+ position_ids: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
@@ -868,6 +895,9 @@ def call(
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
+ position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.max_position_embeddings - 1]`.
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
@@ -897,11 +927,11 @@ def call(
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
- all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`tf.Tensor` of
- shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
- `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
- control over how to convert `input_ids` indices into associated vectors than the model's internal
- embedding lookup matrix.
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`tf.Tensor` of shape
+ `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids`
+ you can choose to directly pass an embedded representation. This is useful if you want more control
+ over how to convert `input_ids` indices into associated vectors than the model's internal embedding
+ lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value
@@ -930,7 +960,10 @@ def call(
past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
# embed positions
- positions = self.embed_positions(input_shape, past_key_values_length)
+ if position_ids is None:
+ positions = self.embed_positions(input_shape, past_key_values_length)
+ else:
+ positions = self.embed_positions(input_shape, position_ids=position_ids)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
@@ -969,7 +1002,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
- message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+ message=(
+ f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(attn_mask)[0]}."
+ ),
)
for idx, decoder_layer in enumerate(self.layers):
@@ -1061,6 +1097,7 @@ def call(
attention_mask: Optional[tf.Tensor] = None,
decoder_input_ids: Optional[tf.Tensor] = None,
decoder_attention_mask: Optional[tf.Tensor] = None,
+ decoder_position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
decoder_head_mask: Optional[tf.Tensor] = None,
cross_attn_head_mask: Optional[tf.Tensor] = None,
@@ -1111,6 +1148,7 @@ def call(
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
+ position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
@@ -1169,6 +1207,7 @@ def call(
attention_mask: Optional[tf.Tensor] = None,
decoder_input_ids: Optional[tf.Tensor] = None,
decoder_attention_mask: Optional[tf.Tensor] = None,
+ decoder_position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
decoder_head_mask: Optional[tf.Tensor] = None,
cross_attn_head_mask: Optional[tf.Tensor] = None,
@@ -1189,6 +1228,7 @@ def call(
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
+ decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@@ -1240,7 +1280,7 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model = TFMBartMainLayer(config, name="model")
self.use_cache = config.use_cache
- # final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
+ # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self.final_logits_bias = self.add_weight(
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
)
@@ -1273,6 +1313,7 @@ def call(
attention_mask: Optional[tf.Tensor] = None,
decoder_input_ids: Optional[tf.Tensor] = None,
decoder_attention_mask: Optional[tf.Tensor] = None,
+ decoder_position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
decoder_head_mask: Optional[tf.Tensor] = None,
cross_attn_head_mask: Optional[tf.Tensor] = None,
@@ -1300,7 +1341,7 @@ def call(
if labels is not None:
labels = tf.where(
labels == self.config.pad_token_id,
- tf.fill(shape_list(labels), -100),
+ tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
labels,
)
use_cache = False
@@ -1313,6 +1354,7 @@ def call(
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
+ decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@@ -1370,6 +1412,7 @@ def prepare_inputs_for_generation(
decoder_input_ids,
past=None,
attention_mask=None,
+ decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@@ -1377,16 +1420,26 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs
):
+
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
+ if decoder_attention_mask is not None: # xla
+ decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
+ elif past is not None: # no xla + past
+ decoder_position_ids = past[0][0].shape[2]
+ else: # no xla + no past
+ decoder_position_ids = tf.range(decoder_input_ids.shape[1])
+
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
+ "decoder_attention_mask": decoder_attention_mask,
+ "decoder_position_ids": decoder_position_ids,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
diff --git a/src/transformers/models/mbart/tokenization_mbart.py b/src/transformers/models/mbart/tokenization_mbart.py
index d6ea6260aec1..b6b4173e50af 100644
--- a/src/transformers/models/mbart/tokenization_mbart.py
+++ b/src/transformers/models/mbart/tokenization_mbart.py
@@ -14,7 +14,6 @@
# limitations under the License.
import os
-from contextlib import contextmanager
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple
@@ -32,8 +31,12 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/mbart-large-en-ro": "https://huggingface.co/facebook/mbart-large-en-ro/resolve/main/sentencepiece.bpe.model",
- "facebook/mbart-large-cc25": "https://huggingface.co/facebook/mbart-large-cc25/resolve/main/sentencepiece.bpe.model",
+ "facebook/mbart-large-en-ro": (
+ "https://huggingface.co/facebook/mbart-large-en-ro/resolve/main/sentencepiece.bpe.model"
+ ),
+ "facebook/mbart-large-cc25": (
+ "https://huggingface.co/facebook/mbart-large-cc25/resolve/main/sentencepiece.bpe.model"
+ ),
}
}
@@ -54,8 +57,8 @@ class MBartTokenizer(PreTrainedTokenizer):
Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
[SentencePiece](https://github.com/google/sentencepiece).
- The tokenization method is `<tokens> <eos> <language code>` for source language documents, and ``<language code>
- <tokens> <eos>``` for target language documents.
+ The tokenization method is `<tokens> <eos> <language code>` for source language documents, and `<language code>
+ <tokens> <eos>` for target language documents.
Examples:
@@ -65,10 +68,7 @@ class MBartTokenizer(PreTrainedTokenizer):
>>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO")
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
- >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
- >>> with tokenizer.as_target_tokenizer():
- ... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
- >>> inputs["labels"] = labels["input_ids"]
+ >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_romanian, return_tensors="pt")
```"""
vocab_files_names = VOCAB_FILES_NAMES
@@ -336,15 +336,11 @@ def prepare_seq2seq_batch(
self.tgt_lang = tgt_lang
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
- @contextmanager
- def as_target_tokenizer(self):
- """
- Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
- sequence-to-sequence models that need a slightly different processing for the labels.
- """
- self.set_tgt_lang_special_tokens(self.tgt_lang)
- yield
- self.set_src_lang_special_tokens(self.src_lang)
+ def _switch_to_input_mode(self):
+ return self.set_src_lang_special_tokens(self.src_lang)
+
+ def _switch_to_target_mode(self):
+ return self.set_tgt_lang_special_tokens(self.tgt_lang)
def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
diff --git a/src/transformers/models/mbart/tokenization_mbart_fast.py b/src/transformers/models/mbart/tokenization_mbart_fast.py
index a172d37913a4..0ac14033a44a 100644
--- a/src/transformers/models/mbart/tokenization_mbart_fast.py
+++ b/src/transformers/models/mbart/tokenization_mbart_fast.py
@@ -14,7 +14,6 @@
# limitations under the License.
import os
-from contextlib import contextmanager
from shutil import copyfile
from typing import List, Optional, Tuple
@@ -38,8 +37,12 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/mbart-large-en-ro": "https://huggingface.co/facebook/mbart-large-en-ro/resolve/main/sentencepiece.bpe.model",
- "facebook/mbart-large-cc25": "https://huggingface.co/facebook/mbart-large-cc25/resolve/main/sentencepiece.bpe.model",
+ "facebook/mbart-large-en-ro": (
+ "https://huggingface.co/facebook/mbart-large-en-ro/resolve/main/sentencepiece.bpe.model"
+ ),
+ "facebook/mbart-large-cc25": (
+ "https://huggingface.co/facebook/mbart-large-cc25/resolve/main/sentencepiece.bpe.model"
+ ),
},
"tokenizer_file": {
"facebook/mbart-large-en-ro": "https://huggingface.co/facebook/mbart-large-en-ro/resolve/main/tokenizer.json",
@@ -65,8 +68,8 @@ class MBartTokenizerFast(PreTrainedTokenizerFast):
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
- The tokenization method is `<tokens> <eos> <language code>` for source language documents, and ``<language code>
- <tokens> <eos>``` for target language documents.
+ The tokenization method is `<tokens> <eos> <language code>` for source language documents, and `<language code>
+ <tokens> <eos>` for target language documents.
Examples:
@@ -78,10 +81,7 @@ class MBartTokenizerFast(PreTrainedTokenizerFast):
... )
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
- >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
- >>> with tokenizer.as_target_tokenizer():
- ... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
- >>> inputs["labels"] = labels["input_ids"]
+ >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_romanian, return_tensors="pt")
```"""
vocab_files_names = VOCAB_FILES_NAMES
@@ -236,15 +236,11 @@ def prepare_seq2seq_batch(
self.tgt_lang = tgt_lang
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
- @contextmanager
- def as_target_tokenizer(self):
- """
- Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
- sequence-to-sequence models that need a slightly different processing for the labels.
- """
- self.set_tgt_lang_special_tokens(self.tgt_lang)
- yield
- self.set_src_lang_special_tokens(self.src_lang)
+ def _switch_to_input_mode(self):
+ return self.set_src_lang_special_tokens(self.src_lang)
+
+ def _switch_to_target_mode(self):
+ return self.set_tgt_lang_special_tokens(self.tgt_lang)
def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
diff --git a/src/transformers/models/mbart50/__init__.py b/src/transformers/models/mbart50/__init__.py
index ee0edc94dfb4..299c0d0da7bb 100644
--- a/src/transformers/models/mbart50/__init__.py
+++ b/src/transformers/models/mbart50/__init__.py
@@ -17,23 +17,43 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, is_tokenizers_available
_import_structure = {}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]
if TYPE_CHECKING:
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_mbart50 import MBart50Tokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_mbart50_fast import MBart50TokenizerFast
else:
diff --git a/src/transformers/models/mbart50/tokenization_mbart50.py b/src/transformers/models/mbart50/tokenization_mbart50.py
index c7e53c61495b..707a97734927 100644
--- a/src/transformers/models/mbart50/tokenization_mbart50.py
+++ b/src/transformers/models/mbart50/tokenization_mbart50.py
@@ -14,7 +14,6 @@
# limitations under the License.
import os
-from contextlib import contextmanager
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple
@@ -32,7 +31,9 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/mbart-large-50-one-to-many-mmt": "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model",
+ "facebook/mbart-large-50-one-to-many-mmt": (
+ "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model"
+ ),
}
}
@@ -100,10 +101,8 @@ class MBart50Tokenizer(PreTrainedTokenizer):
>>> tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
>>> src_text = " UN Chief Says There Is No Military Solution in Syria"
>>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
- >>> model_inputs = tokenizer(src_text, return_tensors="pt")
- >>> with tokenizer.as_target_tokenizer():
- ... labels = tokenizer(tgt_text, return_tensors="pt").input_ids
- >>> # model(**model_inputs, labels=labels) should work
+ >>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
+ >>> # model(**model_inputs) should work
```"""
vocab_files_names = VOCAB_FILES_NAMES
@@ -335,15 +334,11 @@ def prepare_seq2seq_batch(
self.tgt_lang = tgt_lang
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
- @contextmanager
- def as_target_tokenizer(self):
- """
- Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
- sequence-to-sequence models that need a slightly different processing for the labels.
- """
- self.set_tgt_lang_special_tokens(self.tgt_lang)
- yield
- self.set_src_lang_special_tokens(self.src_lang)
+ def _switch_to_input_mode(self):
+ return self.set_src_lang_special_tokens(self.src_lang)
+
+ def _switch_to_target_mode(self):
+ return self.set_tgt_lang_special_tokens(self.tgt_lang)
def set_src_lang_special_tokens(self, src_lang: str) -> None:
"""Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""
diff --git a/src/transformers/models/mbart50/tokenization_mbart50_fast.py b/src/transformers/models/mbart50/tokenization_mbart50_fast.py
index 97e2584a0d00..1ab8ff06e260 100644
--- a/src/transformers/models/mbart50/tokenization_mbart50_fast.py
+++ b/src/transformers/models/mbart50/tokenization_mbart50_fast.py
@@ -14,7 +14,6 @@
# limitations under the License.
import os
-from contextlib import contextmanager
from shutil import copyfile
from typing import List, Optional, Tuple
@@ -37,10 +36,14 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/mbart-large-50-one-to-many-mmt": "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model",
+ "facebook/mbart-large-50-one-to-many-mmt": (
+ "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model"
+ ),
},
"tokenizer_file": {
- "facebook/mbart-large-50-one-to-many-mmt": "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/tokenizer.json",
+ "facebook/mbart-large-50-one-to-many-mmt": (
+ "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/tokenizer.json"
+ ),
},
}
@@ -94,10 +97,8 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
>>> tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
>>> src_text = " UN Chief Says There Is No Military Solution in Syria"
>>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
- >>> model_inputs = tokenizer(src_text, return_tensors="pt")
- >>> with tokenizer.as_target_tokenizer():
- ... labels = tokenizer(tgt_text, return_tensors="pt").input_ids
- >>> # model(**model_inputs, labels=labels) should work
+ >>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
+ >>> # model(**model_inputs) should work
```"""
vocab_files_names = VOCAB_FILES_NAMES
@@ -207,15 +208,11 @@ def prepare_seq2seq_batch(
self.tgt_lang = tgt_lang
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
- @contextmanager
- def as_target_tokenizer(self):
- """
- Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
- sequence-to-sequence models that need a slightly different processing for the labels.
- """
- self.set_tgt_lang_special_tokens(self.tgt_lang)
- yield
- self.set_src_lang_special_tokens(self.src_lang)
+ def _switch_to_input_mode(self):
+ return self.set_src_lang_special_tokens(self.src_lang)
+
+ def _switch_to_target_mode(self):
+ return self.set_tgt_lang_special_tokens(self.tgt_lang)
def set_src_lang_special_tokens(self, src_lang: str) -> None:
"""Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""
diff --git a/src/transformers/models/mctct/__init__.py b/src/transformers/models/mctct/__init__.py
new file mode 100644
index 000000000000..6c28eb2214c5
--- /dev/null
+++ b/src/transformers/models/mctct/__init__.py
@@ -0,0 +1,75 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_speech_available, is_torch_available
+
+
+_import_structure = {
+ "configuration_mctct": ["MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MCTCTConfig"],
+ "processing_mctct": ["MCTCTProcessor"],
+}
+
+
+try:
+ if not is_speech_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_mctct"] = ["MCTCTFeatureExtractor"]
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_mctct"] = [
+ "MCTCT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "MCTCTForCTC",
+ "MCTCTModel",
+ "MCTCTPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_mctct import MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP, MCTCTConfig
+ from .processing_mctct import MCTCTProcessor
+
+ try:
+ if not is_speech_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_mctct import MCTCTFeatureExtractor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_mctct import MCTCT_PRETRAINED_MODEL_ARCHIVE_LIST, MCTCTForCTC, MCTCTModel, MCTCTPreTrainedModel
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/mctct/configuration_mctct.py b/src/transformers/models/mctct/configuration_mctct.py
new file mode 100644
index 000000000000..f71467e65dae
--- /dev/null
+++ b/src/transformers/models/mctct/configuration_mctct.py
@@ -0,0 +1,185 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""M-CTC-T model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "speechbrain/m-ctc-t-large": "https://huggingface.co/speechbrain/m-ctc-t-large/resolve/main/config.json",
+ # See all M-CTC-T models at https://huggingface.co/models?filter=mctct
+}
+
+
+class MCTCTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MCTCTModel`]. It is used to instantiate an
+ M-CTC-T model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the M-CTC-T
+ [speechbrain/m-ctc-t-large](https://huggingface.co/speechbrain/m-ctc-t-large) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 8065):
+ Vocabulary size of the M-CTC-T model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`MCTCTModel`].
+ hidden_size (`int`, *optional*, defaults to 1536):
+ Dimension of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 36):
+ Number of hidden layers in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 6144):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 4):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ attention_head_dim (`int`, *optional*, defaults to 384):
+ Dimensions of each attention head for each attention layer in the Transformer encoder.
+ max_position_embeddings (`int`, *optional*, defaults to 920):
+ The maximum sequence length that this model might ever be used with (after log-mel spectrogram extraction).
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ layerdrop (`float`, *optional*, defaults to 0.3):
+ The probability of dropping an encoder layer during training. The default 0.3 value is used in the original
+ implementation.
+ hidden_act (`str` or `function`, *optional*, defaults to `"relu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.3):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.3):
+ The dropout ratio for the attention probabilities.
+ pad_token_id (`int`, *optional*, defaults to 1):
+ The tokenizer index of the pad token.
+ bos_token_id (`int`, *optional*, defaults to 0):
+ The tokenizer index of the bos token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The tokenizer index of the eos token.
+ conv_glu_dim (`int`, *optional*, defaults to 1):
+ The dimension of the output of the `Conv1dSubsampler` layer on which GLU is applied. Though the original
+ Flashlight code uses the value of 2, here it's adapted to 1 due to transposition differences.
+ conv_dropout (`float`, *optional*, defaults to 0.3):
+ The dropout probability applied inside the `Conv1dSubsampler` layer during training.
+ num_conv_layers (`int`, *optional*, defaults to 1):
+ Number of convolution layers before applying transformer encoder layers.
+ conv_kernel (`List[int]`, *optional*, defaults to `[7]`):
+ The kernel size of the 1D convolution applied before transformer layers. `len(conv_kernel)` must be equal
+ to `num_conv_layers`.
+ conv_stride (`List[int]`, *optional*, defaults to `[3]`):
+ The stride length of the 1D convolution applied before transformer layers. `len(conv_stride)` must be equal
+ to `num_conv_layers`.
+ input_feat_per_channel (`int`, *optional*, defaults to 80):
+ Feature dimensions of the channels of the input to the Conv1D layer.
+ input_channels (`int`, *optional*, defaults to 1):
+ Number of input channels of the input to the Conv1D layer.
+ conv_channels (`List[int]`, *optional*, defaults to None):
+ Channel sizes of intermediate Conv1D layers.
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+ instance of [`MCTCTForCTC`].
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+ of [`MCTCTForCTC`].
+
+ Example:
+
+ ```python
+ >>> from transformers import MCTCTModel, MCTCTConfig
+
+ >>> # Initializing a M-CTC-T mctct-large style configuration
+ >>> configuration = MCTCTConfig()
+
+ >>> # Initializing a model from the mctct-large style configuration
+ >>> model = MCTCTModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "mctct"
+
+ def __init__(
+ self,
+ vocab_size=8065,
+ hidden_size=1536,
+ num_hidden_layers=36,
+ intermediate_size=6144,
+ num_attention_heads=4,
+ attention_head_dim=384,
+ max_position_embeddings=920,
+ layer_norm_eps=1e-5,
+ layerdrop=0.3,
+ hidden_act="relu",
+ initializer_range=0.02,
+ hidden_dropout_prob=0.3,
+ attention_probs_dropout_prob=0.3,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ conv_glu_dim=1,
+ conv_dropout=0.3,
+ num_conv_layers=1,
+ conv_kernel=(7,),
+ conv_stride=(3,),
+ input_feat_per_channel=80,
+ input_channels=1,
+ conv_channels=None,
+ ctc_loss_reduction="sum",
+ ctc_zero_infinity=False,
+ **kwargs
+ ):
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ self.max_position_embeddings = max_position_embeddings
+ self.layer_norm_eps = layer_norm_eps
+ self.layerdrop = layerdrop
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.conv_glu_dim = conv_glu_dim
+ self.conv_dropout = conv_dropout
+ self.num_conv_layers = num_conv_layers
+ self.input_feat_per_channel = input_feat_per_channel
+ self.input_channels = input_channels
+ self.conv_channels = conv_channels
+ self.ctc_loss_reduction = ctc_loss_reduction
+ self.ctc_zero_infinity = ctc_zero_infinity
+
+ # prevents config testing fail with exporting to json
+ self.conv_kernel = list(conv_kernel)
+ self.conv_stride = list(conv_stride)
+
+ if len(self.conv_kernel) != self.num_conv_layers:
+ raise ValueError(
+ "Configuration for convolutional module is incorrect. "
+ "It is required that `len(config.conv_kernel)` == `config.num_conv_layers` "
+ f"but is `len(config.conv_kernel) = {len(self.conv_kernel)}`, "
+ f"`config.num_conv_layers = {self.num_conv_layers}`."
+ )
diff --git a/src/transformers/models/mctct/feature_extraction_mctct.py b/src/transformers/models/mctct/feature_extraction_mctct.py
new file mode 100644
index 000000000000..573551bcf778
--- /dev/null
+++ b/src/transformers/models/mctct/feature_extraction_mctct.py
@@ -0,0 +1,356 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Feature extractor class for M-CTC-T
+"""
+
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+import torchaudio
+
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...file_utils import PaddingStrategy, TensorType
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MCTCTFeatureExtractor(SequenceFeatureExtractor):
+ r"""
+ Constructs a M-CTC-T feature extractor.
+
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+ most of the main methods. Users should refer to this superclass for more information regarding those methods. This
+ code has been adapted from Flashlight's C++ code. For more information about the implementation, one can refer to
+ this [notebook](https://colab.research.google.com/drive/1GLtINkkhzms-IsdcGy_-tVCkv0qNF-Gt#scrollTo=pMCRGMmUC_an)
+ that walks the user step by step through the implementation.
+
+ Args:
+ feature_size (`int`, defaults to 80):
+ The feature dimension of the extracted features. This is the number of mel-frequency bins.
+ sampling_rate (`int`, defaults to 16000):
+ The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
+ padding_value (`float`, defaults to 0.0):
+ The value that is used to fill the padding values.
+ hop_length (`int`, defaults to 10):
+ Number of milliseconds between successive windows, otherwise referred to as the "shift" in many papers.
+ win_length (`int`, defaults to 25):
+ Number of milliseconds per window.
+ win_function (`str`, defaults to `"hamming_window"`):
+ Name for the window function used for windowing, must be accessible via `torch.{win_function}`
+ frame_signal_scale (`float`, defaults to 32768.0):
+ Constant multiplied in creating the frames before applying DFT.
+ preemphasis_coeff (`float`, defaults to 0.97):
+ Constant multiplied in applying Pre-emphasis before DFT.
+ mel_floor (`float`, defaults to 1.0):
+ Minimum value of mel frequency banks.
+ normalize_means (`bool`, *optional*, defaults to `True`):
+ Whether or not to zero-mean normalize the extracted features.
+ normalize_vars (`bool`, *optional*, defaults to `True`):
+ Whether or not to unit-variance normalize the extracted features.
+ """
+
+ model_input_names = ["input_features", "attention_mask"]
+
+ def __init__(
+ self,
+ feature_size=80,
+ sampling_rate=16000,
+ padding_value=0.0,
+ hop_length=10,
+ win_length=25,
+ win_function="hamming_window",
+ frame_signal_scale=32768.0,
+ preemphasis_coeff=0.97,
+ mel_floor=1.0,
+ normalize_means=True,
+ normalize_vars=True,
+ return_attention_mask=False,
+ **kwargs
+ ):
+ super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
+
+ self.feature_size = feature_size
+ self.sampling_rate = sampling_rate
+ self.padding_value = padding_value
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.frame_signal_scale = frame_signal_scale
+ self.preemphasis_coeff = preemphasis_coeff
+ self.mel_floor = mel_floor
+ self.normalize_means = normalize_means
+ self.normalize_vars = normalize_vars
+ self.win_function = win_function
+ self.return_attention_mask = return_attention_mask
+
+ self.sample_size = win_length * sampling_rate // 1000
+ self.sample_stride = hop_length * sampling_rate // 1000
+
+ self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size)))
+ self.n_freqs = (self.n_fft // 2) + 1
+
+ @staticmethod
+ def _num_frames_calc(in_size, frame_size, frame_stride):
+ return int(1 + np.floor((in_size - frame_size) * 1 / frame_stride))
+
+ @staticmethod
+ def _frame_signal(one_waveform, n_frames, frame_signal_scale, window_length, sample_stride):
+ scale = frame_signal_scale
+ frames = np.zeros(n_frames * window_length)
+ for frame_idx in range(n_frames):
+ start = frame_idx * window_length
+ end = (frame_idx + 1) * window_length
+ wave_start = frame_idx * sample_stride
+ wave_end = frame_idx * sample_stride + window_length
+ frames[start:end] = scale * one_waveform[wave_start:wave_end]
+
+ return frames
+
+ @staticmethod
+ def _apply_preemphasis_inplace(frames, window_length, preemphasis_coeff):
+ if frames.size % window_length != 0:
+ raise ValueError(
+ f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
+ f" window_length={window_length}."
+ )
+
+ n_frames = frames.size // window_length
+ for frame_idx in range(n_frames, 0, -1):
+ start = (frame_idx - 1) * window_length
+ end = frame_idx * window_length - 1
+ frames[start + 1 : end + 1] -= preemphasis_coeff * frames[start:end]
+ frames[start] *= 1 - preemphasis_coeff
+
+ @staticmethod
+ def _windowing(frames, window_length, window):
+ if frames.size % window_length != 0:
+ raise ValueError(
+ f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
+ f" window_length={window_length}."
+ )
+
+ shaped = frames.reshape(-1, window_length)
+ shaped = window * shaped
+ return shaped
+
+ @staticmethod
+ def _dft(frames, K, n_frames, n_samples, n_fft):
+ dft = np.zeros([n_frames, K])
+
+ for frame in range(n_frames):
+ begin = frame * n_samples
+
+ inwards_buffer = frames[begin : begin + n_samples]
+ inwards_buffer = np.pad(inwards_buffer, (0, n_fft - n_samples), "constant")
+ out = np.fft.rfft(inwards_buffer)
+
+ dft[frame] = np.abs(out[:K])
+
+ return dft
+
+ def _extract_mfsc_features(self, one_waveform: np.array) -> np.ndarray:
+ """
+ Extracts MFSC Features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC code.
+ """
+ if self.win_function == "hamming_window":
+ window = torch.hamming_window(window_length=self.sample_size, periodic=False, alpha=0.54, beta=0.46)
+ else:
+ window = getattr(torch, self.win_function)()
+
+ window = window.numpy()
+
+ fbanks = torchaudio.functional.melscale_fbanks(
+ n_freqs=self.n_freqs,
+ f_min=0.0, # change this to zeros
+ f_max=self.sampling_rate / 2.0,
+ n_mels=self.feature_size,
+ sample_rate=self.sampling_rate,
+ )
+
+ fbanks = fbanks.numpy()
+
+ n_frames = self._num_frames_calc(one_waveform.size, self.sample_size, self.sample_stride)
+
+ frames = self._frame_signal(
+ one_waveform, n_frames, self.frame_signal_scale, self.sample_size, self.sample_stride
+ )
+
+ self._apply_preemphasis_inplace(frames, self.sample_size, self.preemphasis_coeff)
+
+ frames = self._windowing(frames, self.sample_size, window)
+
+ dft_out = self._dft(frames.flatten(), self.n_freqs, n_frames, self.sample_size, self.n_fft)
+
+ # msfc_features = STFT * mel frequency banks.
+ msfc_features = np.einsum("...tf,fm->...tm", dft_out, fbanks)
+
+ # clamp feature values then log scale, as implemented in flashlight
+ msfc_features = np.maximum(msfc_features, self.mel_floor)
+ msfc_features = np.log(msfc_features)
+
+ return msfc_features
+
+ def _normalize_one(self, x, input_length, padding_value):
+ # make sure we normalize float32 arrays
+ if self.normalize_means:
+ mean = x[:input_length].mean(axis=0)
+ x = np.subtract(x, mean)
+ if self.normalize_vars:
+ std = x[:input_length].std(axis=0)
+ x = np.divide(x, std)
+
+ if input_length < x.shape[0]:
+ x[input_length:] = padding_value
+
+ # make sure array is in float32
+ x = x.astype(np.float32)
+
+ return x
+
+ def normalize(
+ self, input_features: List[np.ndarray], attention_mask: Optional[np.ndarray] = None
+ ) -> List[np.ndarray]:
+ lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features]
+ return [self._normalize_one(x, n, self.padding_value) for x, n in zip(input_features, lengths)]
+
+ def __call__(
+ self,
+ raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
+ padding: Union[bool, str, PaddingStrategy] = False,
+ max_length: Optional[int] = None,
+ truncation: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ sampling_rate: Optional[int] = None,
+ **kwargs
+ ) -> BatchFeature:
+ """
+        Main method to featurize and prepare for the model one or several sequence(s). It returns the
+ log-mel spectrogram of the input audio, as implemented in the original Flashlight MFSC feature extraction code.
+
+ Args:
+ raw_speech (`torch.Tensor`, `np.ndarray`, `List[float]`, `List[torch.Tensor]`, `List[np.ndarray]`, `List[List[float]]`):
+ The sequence or batch of sequences to be padded. Each sequence can be a tensor, a numpy array, a list
+ of float values, a list of tensors, a list of numpy arrays or a list of list of float values.
+ padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ truncation (`bool`):
+ Activates truncation to cut input sequences longer than *max_length* to *max_length*.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value.
+
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                >= 7.0 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific feature_extractor's default.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ sampling_rate (`int`, *optional*):
+ The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
+ `sampling_rate` at the forward call to prevent silent errors.
+            padding_value (`float`, defaults to 0.0):
+                The value that is used to fill the padding values / vectors.
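+
+        Example (a minimal sketch, assuming this checkpoint expects 16 kHz audio; the waveform here is
+        random data rather than a real recording):
+
+        ```python
+        >>> import numpy as np
+        >>> from transformers import MCTCTFeatureExtractor
+
+        >>> feature_extractor = MCTCTFeatureExtractor.from_pretrained("speechbrain/m-ctc-t-large")
+        >>> waveform = np.random.randn(16000).astype(np.float32)  # one second of audio
+        >>> inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="np")
+        >>> inputs["input_features"].shape  # (batch, frames, feature_size)
+        ```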
+ """
+
+ if sampling_rate is not None:
+ if sampling_rate != self.sampling_rate:
+ raise ValueError(
+ f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+ f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
+ f" {self.sampling_rate} and not {sampling_rate}."
+ )
+ else:
+ logger.warning(
+ "It is strongly recommended to pass the ``sampling_rate`` argument to this function. "
+ "Failing to do so can result in silent errors that might be hard to debug."
+ )
+
+ is_batched = bool(
+ isinstance(raw_speech, (list, tuple))
+ and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
+ )
+
+ if is_batched:
+ raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
+ elif not is_batched and not isinstance(raw_speech, np.ndarray):
+ raw_speech = np.asarray(raw_speech, dtype=np.float32)
+ elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
+ raw_speech = raw_speech.astype(np.float32)
+
+ # always return batch
+ if not is_batched:
+ raw_speech = [raw_speech]
+
+ # extract fbank features
+ features = [self._extract_mfsc_features(one_waveform) for one_waveform in raw_speech]
+
+ # convert into correct format for padding
+ encoded_inputs = BatchFeature({"input_features": features})
+
+ padded_inputs = self.pad(
+ encoded_inputs,
+ padding=padding,
+ max_length=max_length,
+ truncation=truncation,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=True,
+ **kwargs,
+ )
+ # make sure list is in array format
+ input_features = padded_inputs.get("input_features")
+ if isinstance(input_features[0], list):
+ padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
+
+ attention_mask = padded_inputs.get("attention_mask")
+ if attention_mask is not None:
+ padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]
+
+ if self.normalize_means or self.normalize_vars:
+ attention_mask = (
+ np.array(attention_mask, dtype=np.int32)
+ if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
+ and padding
+ else None
+ )
+ padded_inputs["input_features"] = self.normalize(
+ padded_inputs["input_features"], attention_mask=attention_mask
+ )
+
+ if return_tensors is not None:
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
+
+ return padded_inputs
diff --git a/src/transformers/models/mctct/modeling_mctct.py b/src/transformers/models/mctct/modeling_mctct.py
new file mode 100755
index 000000000000..3eb59a0c419b
--- /dev/null
+++ b/src/transformers/models/mctct/modeling_mctct.py
@@ -0,0 +1,825 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch M-CTC-T model."""
+
+
+import math
+import random
+from typing import Optional
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...deepspeed import is_deepspeed_zero3_enabled
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput
+from ...modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from ...pytorch_utils import is_torch_greater_than_1_6
+from ...utils import logging
+from .configuration_mctct import MCTCTConfig
+
+
+logger = logging.get_logger(__name__)
+
+_HIDDEN_STATES_START_POSITION = 1
+
+_CONFIG_FOR_DOC = "MCTCTConfig"
+_PROCESSOR_FOR_DOC = "MCTCTProcessor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "speechbrain/m-ctc-t-large"
+_EXPECTED_OUTPUT_SHAPE = [1, 195, 1536]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = '"Mr. Quilter is the apostle of the middle classes, and we\'re glad to welcome his gospel."'
+_CTC_EXPECTED_LOSS = 1885.65
+
+
+MCTCT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "speechbrain/m-ctc-t-large",
+ # See all M-CTC-T models at https://huggingface.co/models?filter=mctct
+]
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
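+    # For intuition (values illustrative): a mask row [1, 1, 0] yields 0.0 wherever attention is allowed
+    # and torch.finfo(dtype).min where it is masked, so the result can simply be added to the attention
+    # scores.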
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+class MCTCTConv1dSubsampler(nn.Module):
+ """
+ Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation
+ via gated linear units (https://arxiv.org/abs/1911.08460)
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.glu_dim = config.conv_glu_dim
+
+ self.dropout = nn.Dropout(config.conv_dropout)
+
+ self.num_layers = config.num_conv_layers
+ self.in_channels = config.input_feat_per_channel * config.input_channels
+
+ if self.num_layers > 1:
+ if config.conv_channels is None:
+ raise ValueError(
+ "Need to specify `conv_channels` configuration in `MCTCTConfig` to use multiple convolution"
+ " layers."
+ )
+
+ self.mid_channels = config.conv_channels
+ else:
+ self.mid_channels = None
+
+ self.out_channels = config.hidden_size * 2 # considering GLU halving
+ self.kernel_size = config.conv_kernel
+ self.stride = config.conv_stride
+
+        # NOTE: MCTCT by construction only uses one convolution kernel. This is kept flexible to allow for
+        # multiple layers of convolutions, but it is not clear whether this model definition should just
+        # restrict it to one layer. This becomes especially relevant when considering the padding computed
+        # in the first line of forward().
+ self.conv_layers = nn.ModuleList(
+ nn.Conv1d(
+ self.in_channels if i == 0 else self.mid_channels[i],
+ self.mid_channels[i] if i < self.num_layers - 1 else self.out_channels,
+ kernel_size=k,
+ stride=self.stride[i],
+ padding="valid",
+ )
+ for i, k in enumerate(self.kernel_size)
+ )
+
+ def forward(self, input_features):
+ # NOTE: in reference to the NOTE in __init__, right now it just calculates padding as if
+ # there will be just one conv layer.
+ padding = sum([size // 2 for size in self.kernel_size]) # (7, 7) -> (3, 3)
+
+ input_features = torch.nn.functional.pad(input_features, (0, 0, padding, padding), "constant", 0)
+ hidden_states = input_features.transpose(1, 2).contiguous() # -> Batch x Frame x Time
+ for conv in self.conv_layers:
+ hidden_states = conv(hidden_states)
+ hidden_states = nn.functional.glu(hidden_states, dim=self.glu_dim)
+ hidden_states = self.dropout(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2).contiguous() # -> Batch x Time x Frame
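+        # Overall shape flow (single conv layer, sizes illustrative): (batch, time, in_channels) is padded
+        # along time, convolved with stride to 2 * hidden_size channels, halved by the GLU, and returned
+        # as (batch, time', hidden_size).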
+ return hidden_states
+
+
+class MCTCTEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.LayerNorm = MCTCTLayerNorm()
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ if is_torch_greater_than_1_6:
+ self.register_buffer(
+ "token_type_ids",
+ torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+ persistent=False,
+ )
+
+ def forward(
+ self, input_features=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ input_shape = input_features.size() if input_features is not None else inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+        # Set token_type_ids to the registered buffer from the constructor, where it is all zeros. This
+        # usually occurs when token_type_ids is auto-generated; the registered buffer helps users trace
+        # the model without passing token_type_ids, and solves issue #5664.
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_features)
+
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class MCTCTSelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = config.attention_head_dim
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+ self.is_decoder = config.is_decoder
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def reshape_fortran(self, x, shape):
+ if len(x.shape) > 0:
+ x = x.permute(*reversed(range(len(x.shape))))
+ return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))
+
+ def relative_position_embedding_rotate(self, scores):
+        # NOTE: should re-evaluate whether this re-implementation was truly necessary, or whether the
+        # complete overhaul worked because of some other part of the code. Adding this and the
+        # Fortran-order reshape code seems very undesirable.
+ scores = scores.permute(0, 2, 3, 1) # e.g. [10, 1839, 14, 4]
+
+ batch, hidden_state, seq_len, heads = scores.shape
+
+ # e.g. [10, 1853, 14, 4]
+ scores = torch.cat((scores, torch.zeros((batch, seq_len, seq_len, heads), device=scores.device)), dim=1)
+
+ # e.g. [10, 25942, 1, 4]
+ scores = self.reshape_fortran(scores, [batch, (hidden_state + seq_len) * seq_len, 1, heads])
+
+ # e.g. [10, 25928, 1, 4]
+ scores = scores[:, : (seq_len + hidden_state - 1) * seq_len]
+
+ # e.g. [10, 1852, 14, 4]
+ scores = self.reshape_fortran(scores, [batch, hidden_state + seq_len - 1, seq_len, heads])
+
+ halfpoint = hidden_state // 2
+ scores = scores[:, halfpoint : halfpoint + seq_len].transpose(1, 2) # e.g. [10, 14, 14, 4]
+
+ return scores.permute(0, 3, 1, 2)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+ mixed_query_layer = mixed_query_layer / math.sqrt(self.attention_head_size)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ # relative key position embeddings
+ positional_embedding = self.distance_embedding.weight
+ relative_position_scores = torch.einsum("lh, bche -> bcle", positional_embedding, query_layer.transpose(2, 3))
+
+ relative_position_scores = self.relative_position_embedding_rotate(relative_position_scores)
+ attention_scores = attention_scores + relative_position_scores
+
+ if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the MCTCTModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).flatten(start_dim=-2)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+class MCTCTLayerNorm(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.singleton_weight = nn.Parameter(torch.ones(1))
+ self.singleton_bias = nn.Parameter(torch.zeros(1))
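+        # NOTE: despite the name, this applies a single learned scale and bias to every element; it does
+        # not normalize by mean and variance like nn.LayerNorm.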
+
+ def forward(self, hidden_states):
+ return (hidden_states * self.singleton_weight) + self.singleton_bias
+
+
+class MCTCTSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class MCTCTAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.self = MCTCTSelfAttention(config)
+ self.output = MCTCTSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+
+ return outputs
+
+
+class MCTCTIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class MCTCTOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class MCTCTLayer(nn.Module):
+ def __init__(self, config: MCTCTConfig):
+ super().__init__()
+
+ self.seq_len_dim = 1
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+
+ self.intermediate = MCTCTIntermediate(config)
+ self.attention = MCTCTAttention(config)
+ self.is_decoder = config.is_decoder
+ self.output = MCTCTOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ ):
+ self_attention_outputs = self.attention(
+ hidden_states, attention_mask, head_mask, output_attentions=output_attentions
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class MCTCTPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = MCTCTConfig
+ base_model_prefix = "mctct"
+ main_input_name = "input_features"
+ _keys_to_ignore_on_load_missing = ["position_ids"]
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, MCTCTLayerNorm):
+ module.singleton_weight.data.fill_(1.0)
+ module.singleton_bias.data.zero_()
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+ """
+ Computes the output length of the convolutional layers
+ """
+ dilation = 1
+ for _, kernel_sz, stride in zip(
+ range(self.config.num_conv_layers), self.config.conv_kernel, self.config.conv_stride
+ ):
+ padding = kernel_sz // 2
+ input_lengths = input_lengths + 2 * padding - dilation * (kernel_sz - 1) - 1
+ input_lengths = torch.div(input_lengths, stride, rounding_mode="trunc") + 1
+
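+        # Worked example, assuming a single conv layer with kernel size 7 and stride 3: padding = 3, so an
+        # input of 1000 frames gives (1000 + 2 * 3 - 6 - 1) // 3 + 1 = 334 output frames.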
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(self, feature_vector_length, attention_mask):
+        # generate creates a 3D attention mask because of the shape of input_features;
+        # convert it to 2D if that's the case
+ if len(attention_mask.shape) > 2:
+ attention_mask = attention_mask[:, :, -1]
+
+ # subsampled_lengths = attention_mask.sum(-1)
+ subsampled_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
+ bsz = attention_mask.size()[0]
+ attention_mask = torch.zeros(
+ (bsz, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+
+        # these two operations make sure that all values
+        # before the output length indices are attended to
+ attention_mask[(torch.arange(bsz, device=attention_mask.device), subsampled_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long()
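+        # e.g. feature_vector_length=5 with a subsampled length of 3: the one-hot [0, 0, 1, 0, 0] becomes
+        # [1, 1, 1, 0, 0], attending to the first three frames.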
+ return attention_mask
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (MCTCTEncoder)):
+ module.gradient_checkpointing = value
+
+
+MCTCT_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`MCTCTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+MCTCT_INPUTS_DOCSTRING = r"""
+ Args:
+        input_features (`torch.FloatTensor` of shape `({0})`):
+            Float values of the extracted log-mel features of the input audio.
+
+            Features can be obtained by passing the raw speech waveform to [`MCTCTFeatureExtractor`]. See
+            [`MCTCTFeatureExtractor.__call__`] for details.
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class MCTCTEncoder(MCTCTPreTrainedModel):
+ def __init__(self, config: MCTCTConfig):
+ super().__init__(config)
+ self.hidden_dropout_prob = config.hidden_dropout_prob
+
+ self.layer_norm = MCTCTLayerNorm()
+ self.conv = MCTCTConv1dSubsampler(config)
+ self.layers = nn.ModuleList([MCTCTLayer(config) for _ in range(config.num_hidden_layers)])
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ input_features,
+ attention_mask,
+ head_mask,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ input_features = self.layer_norm(input_features)
+
+ inputs_embeds = self.conv(input_features)
+
+ # subsample attention mask if necessary
+ if attention_mask is not None:
+ attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[1], attention_mask)
+
+ hidden_states = nn.functional.dropout(inputs_embeds, p=self.hidden_dropout_prob, training=self.training)
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, "
+ f"but it is for {head_mask.size()[0]}."
+ )
+
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = random.uniform(0, 1)
+
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
+ # under deepspeed zero3 all gpus must run in sync
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ attention_mask,
+ (head_mask[idx] if head_mask is not None else None),
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+@add_start_docstrings(
+ "The bare M-CTC-T Model transformer outputting raw hidden-states without any specific head on top.",
+ MCTCT_START_DOCSTRING,
+)
+class MCTCTModel(MCTCTPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.encoder = MCTCTEncoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(MCTCT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_PROCESSOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ input_features,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_features is None:
+ raise ValueError("You have to specify input_features.")
+
+ encoder_outputs = self.encoder(
+ input_features,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+
+ if not return_dict:
+ return (sequence_output,) + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """MCTCT Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ MCTCT_START_DOCSTRING,
+)
+class MCTCTForCTC(MCTCTPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.mctct = MCTCTModel(config)
+
+ if config.vocab_size is None:
+ raise ValueError(
+ f"You are trying to instantiate {self.__class__} with a configuration that "
+ "does not define the vocabulary size of the language model head. Please "
+                "instantiate the model as follows: `MCTCTForCTC.from_pretrained(..., vocab_size=vocab_size)`, "
+ "or define `vocab_size` of your model's configuration."
+ )
+ output_hidden_size = config.hidden_size
+
+ self.ctc_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(MCTCT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_PROCESSOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CTC_EXPECTED_OUTPUT,
+ expected_loss=_CTC_EXPECTED_LOSS,
+ )
+ def forward(
+ self,
+ input_features,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ labels=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+ config.vocab_size - 1]`.
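+
+        Example (a rough sketch; `waveform` is a 1-D float array of 16 kHz audio that you provide):
+
+        ```python
+        >>> import torch
+        >>> from transformers import MCTCTForCTC, MCTCTProcessor
+
+        >>> processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large")
+        >>> model = MCTCTForCTC.from_pretrained("speechbrain/m-ctc-t-large")
+
+        >>> inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits
+        >>> predicted_ids = torch.argmax(logits, dim=-1)
+        >>> transcription = processor.batch_decode(predicted_ids)[0]
+        ```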
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ outputs = self.mctct(
+ input_features,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+
+ logits = self.ctc_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+
+ if labels.max() >= self.config.vocab_size:
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+ # retrieve loss input_lengths from attention_mask
+ attention_mask = (
+ attention_mask
+ if attention_mask is not None
+ else torch.ones(input_features.shape[:-1], dtype=torch.long)
+ )
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+ # assuming that padded tokens are filled with -100
+ # when not being attended to
+ labels_mask = labels >= 0
+ target_lengths = labels_mask.sum(-1)
+ flattened_targets = labels.masked_select(labels_mask)
+
+ # ctc_loss doesn't support fp16
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = nn.functional.ctc_loss(
+ log_probs,
+ flattened_targets,
+ input_lengths,
+ target_lengths,
+ blank=self.config.pad_token_id,
+ reduction=self.config.ctc_loss_reduction,
+ zero_infinity=self.config.ctc_zero_infinity,
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
diff --git a/src/transformers/models/mctct/processing_mctct.py b/src/transformers/models/mctct/processing_mctct.py
new file mode 100644
index 000000000000..2e05020196ac
--- /dev/null
+++ b/src/transformers/models/mctct/processing_mctct.py
@@ -0,0 +1,140 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Speech processor class for M-CTC-T
+"""
+import warnings
+from contextlib import contextmanager
+
+from ...processing_utils import ProcessorMixin
+
+
+class MCTCTProcessor(ProcessorMixin):
+ r"""
+ Constructs a MCTCT processor which wraps a MCTCT feature extractor and a MCTCT tokenizer into a single processor.
+
+ [`MCTCTProcessor`] offers all the functionalities of [`MCTCTFeatureExtractor`] and [`AutoTokenizer`]. See the
+ [`~MCTCTProcessor.__call__`] and [`~MCTCTProcessor.decode`] for more information.
+
+ Args:
+ feature_extractor (`MCTCTFeatureExtractor`):
+ An instance of [`MCTCTFeatureExtractor`]. The feature extractor is a required input.
+ tokenizer (`AutoTokenizer`):
+ An instance of [`AutoTokenizer`]. The tokenizer is a required input.
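+
+    Example (a minimal sketch; `waveform` is a 1-D float array of 16 kHz audio that you provide):
+
+    ```python
+    >>> from transformers import MCTCTProcessor
+
+    >>> processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large")
+    >>> features = processor(waveform, sampling_rate=16000, return_tensors="pt")  # log-mel features
+    >>> labels = processor(text="hello world").input_ids  # tokenized transcription
+    ```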
+ """
+ feature_extractor_class = "MCTCTFeatureExtractor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__(feature_extractor, tokenizer)
+ self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
+
+ def __call__(self, *args, **kwargs):
+ """
+ When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor's
+ [`~MCTCTFeatureExtractor.__call__`] and returns its output. If used in the context
+ [`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to AutoTokenizer's
+        [`~AutoTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
+ """
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor(*args, **kwargs)
+
+ if "raw_speech" in kwargs:
+ warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
+ audio = kwargs.pop("raw_speech")
+ else:
+ audio = kwargs.pop("audio", None)
+ text = kwargs.pop("text", None)
+ if len(args) > 0:
+ audio = args[0]
+ args = args[1:]
+
+ if audio is None and text is None:
+ raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+ if audio is not None:
+ inputs = self.feature_extractor(audio, *args, **kwargs)
+ if text is not None:
+ encodings = self.tokenizer(text, **kwargs)
+
+ if text is None:
+ return inputs
+ elif audio is None:
+ return encodings
+ else:
+ inputs["labels"] = encodings["input_ids"]
+ return inputs
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer
+ to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def pad(self, *args, **kwargs):
+ """
+ When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor's
+ [`~MCTCTFeatureExtractor.pad`] and returns its output. If used in the context
+ [`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
+ [`~PreTrainedTokenizer.pad`]. Please refer to the docstring of the above two methods for more information.
+ """
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor.pad(*args, **kwargs)
+
+ input_features = kwargs.pop("input_features", None)
+ labels = kwargs.pop("labels", None)
+ if len(args) > 0:
+ input_features = args[0]
+ args = args[1:]
+
+ if input_features is not None:
+ input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
+ if labels is not None:
+ labels = self.tokenizer.pad(labels, **kwargs)
+
+ if labels is None:
+ return input_features
+ elif input_features is None:
+ return labels
+ else:
+ input_features["labels"] = labels["input_ids"]
+ return input_features
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
+ docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @contextmanager
+ def as_target_processor(self):
+ """
+ Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning MCTCT.
+ """
+ warnings.warn(
+ "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+ "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+            "your audio inputs, or in a separate call)."
+ )
+ self._in_target_context_manager = True
+ self.current_processor = self.tokenizer
+ yield
+ self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
diff --git a/src/transformers/models/megatron_bert/__init__.py b/src/transformers/models/megatron_bert/__init__.py
index d49ab274e565..9075b898377a 100644
--- a/src/transformers/models/megatron_bert/__init__.py
+++ b/src/transformers/models/megatron_bert/__init__.py
@@ -17,14 +17,19 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_megatron_bert"] = [
"MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"MegatronBertForCausalLM",
@@ -42,7 +47,12 @@
if TYPE_CHECKING:
from .configuration_megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_megatron_bert import (
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MegatronBertForCausalLM,
diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py
index e914822736d5..371782c2976e 100755
--- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py
+++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py
@@ -460,7 +460,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise AttributeError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -1426,7 +1427,8 @@ def forward(
if "next_sentence_label" in kwargs:
warnings.warn(
- "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
+ " `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
diff --git a/src/transformers/models/mluke/__init__.py b/src/transformers/models/mluke/__init__.py
index acd6dff11f19..b6582e35a9d0 100644
--- a/src/transformers/models/mluke/__init__.py
+++ b/src/transformers/models/mluke/__init__.py
@@ -18,17 +18,27 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available
_import_structure = {}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_mluke"] = ["MLukeTokenizer"]
if TYPE_CHECKING:
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_mluke import MLukeTokenizer
diff --git a/src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py
index c75a710cee2f..9d61c3bc8e27 100644
--- a/src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py
@@ -153,7 +153,8 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
if not (outputs.entity_last_hidden_state.shape == expected_shape):
raise ValueError(
- f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is {expected_shape}"
+ f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is"
+ f" {expected_shape}"
)
if not torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
raise ValueError
diff --git a/src/transformers/models/mluke/tokenization_mluke.py b/src/transformers/models/mluke/tokenization_mluke.py
index 24a6304fc145..57272c391fb3 100644
--- a/src/transformers/models/mluke/tokenization_mluke.py
+++ b/src/transformers/models/mluke/tokenization_mluke.py
@@ -342,7 +342,8 @@ def __init__(
self.max_entity_length = 2
else:
raise ValueError(
- f"Task {task} not supported. Select task from ['entity_classification', 'entity_pair_classification', 'entity_span_classification'] only."
+ f"Task {task} not supported. Select task from ['entity_classification', 'entity_pair_classification',"
+ " 'entity_span_classification'] only."
)
self.max_mention_length = max_mention_length
@@ -707,7 +708,7 @@ def _check_entity_input_format(self, entities: Optional[EntityInput], entity_spa
raise ValueError("entity_spans should be given as a list")
elif len(entity_spans) > 0 and not isinstance(entity_spans[0], tuple):
raise ValueError(
- "entity_spans should be given as a list of tuples " "containing the start and end character indices"
+ "entity_spans should be given as a list of tuples containing the start and end character indices"
)
if entities is not None:
@@ -1119,7 +1120,8 @@ def prepare_for_model(
if num_invalid_entities != 0:
logger.warning(
- f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the truncation of input tokens"
+ f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the"
+ " truncation of input tokens"
)
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and total_entity_len > max_entity_length:
@@ -1144,7 +1146,7 @@ def prepare_for_model(
entity_position_ids = []
entity_start_positions = []
entity_end_positions = []
- for (token_spans, offset) in (
+ for token_spans, offset in (
(valid_entity_token_spans, entity_token_offset),
(valid_pair_entity_token_spans, pair_entity_token_offset),
):
@@ -1294,7 +1296,7 @@ def pad(
else:
raise ValueError(
f"type of {first_element} unknown: {type(first_element)}. "
- f"Should be one of a python, numpy, pytorch or tensorflow object."
+ "Should be one of a python, numpy, pytorch or tensorflow object."
)
for key, value in encoded_inputs.items():
@@ -1507,7 +1509,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(entity_vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.entity_vocab, ensure_ascii=False))
+ f.write(json.dumps(self.entity_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return out_vocab_file, entity_vocab_file
diff --git a/src/transformers/models/mmbt/__init__.py b/src/transformers/models/mmbt/__init__.py
index 763a256f1a20..d95a2cc8d84a 100644
--- a/src/transformers/models/mmbt/__init__.py
+++ b/src/transformers/models/mmbt/__init__.py
@@ -18,21 +18,29 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
-_import_structure = {
- "configuration_mmbt": ["MMBTConfig"],
-}
+_import_structure = {"configuration_mmbt": ["MMBTConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_mmbt"] = ["MMBTForClassification", "MMBTModel", "ModalEmbeddings"]
if TYPE_CHECKING:
from .configuration_mmbt import MMBTConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
else:
diff --git a/src/transformers/models/mobilebert/__init__.py b/src/transformers/models/mobilebert/__init__.py
index 505dabe18791..ae91c38bdfb3 100644
--- a/src/transformers/models/mobilebert/__init__.py
+++ b/src/transformers/models/mobilebert/__init__.py
@@ -18,18 +18,38 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
- "configuration_mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig"],
+ "configuration_mobilebert": [
+ "MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "MobileBertConfig",
+ "MobileBertOnnxConfig",
+ ],
"tokenization_mobilebert": ["MobileBertTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_mobilebert_fast"] = ["MobileBertTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_mobilebert"] = [
"MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"MobileBertForMaskedLM",
@@ -45,7 +65,12 @@
"load_tf_weights_in_mobilebert",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_mobilebert"] = [
"TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFMobileBertForMaskedLM",
@@ -62,13 +87,27 @@
if TYPE_CHECKING:
- from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
+ from .configuration_mobilebert import (
+ MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ MobileBertConfig,
+ MobileBertOnnxConfig,
+ )
from .tokenization_mobilebert import MobileBertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_mobilebert_fast import MobileBertTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_mobilebert import (
MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MobileBertForMaskedLM,
@@ -84,7 +123,12 @@
load_tf_weights_in_mobilebert,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertForMaskedLM,
diff --git a/src/transformers/models/mobilebert/configuration_mobilebert.py b/src/transformers/models/mobilebert/configuration_mobilebert.py
index 27863235b3d7..73b8844ed763 100644
--- a/src/transformers/models/mobilebert/configuration_mobilebert.py
+++ b/src/transformers/models/mobilebert/configuration_mobilebert.py
@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" MobileBERT model configuration"""
+from collections import OrderedDict
+from typing import Mapping
from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
from ...utils import logging
@@ -165,3 +168,20 @@ def __init__(
self.true_hidden_size = hidden_size
self.classifier_dropout = classifier_dropout
+
+
+# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Bert->MobileBert
+class MobileBertOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "multiple-choice":
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+ else:
+ dynamic_axis = {0: "batch", 1: "sequence"}
+ return OrderedDict(
+ [
+ ("input_ids", dynamic_axis),
+ ("attention_mask", dynamic_axis),
+ ("token_type_ids", dynamic_axis),
+ ]
+ )
diff --git a/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py
index 5c03331eb3d9..022a9d036cdb 100644
--- a/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py
@@ -46,8 +46,10 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file,
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained MobileBERT model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained MobileBERT model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py
index 6d2b2d3ce2e5..6bc306a6e05e 100644
--- a/src/transformers/models/mobilebert/modeling_mobilebert.py
+++ b/src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -226,9 +226,9 @@ def forward(
# dimensional output.
inputs_embeds = torch.cat(
[
- nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0),
+ nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0),
inputs_embeds,
- nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0),
+ nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0),
],
dim=2,
)
@@ -1188,7 +1188,8 @@ def forward(
if "next_sentence_label" in kwargs:
warnings.warn(
- "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
+ " `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
diff --git a/src/transformers/models/mobilevit/__init__.py b/src/transformers/models/mobilevit/__init__.py
new file mode 100644
index 000000000000..cd639f50323c
--- /dev/null
+++ b/src/transformers/models/mobilevit/__init__.py
@@ -0,0 +1,79 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+
+
+_import_structure = {
+ "configuration_mobilevit": ["MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileViTConfig", "MobileViTOnnxConfig"],
+}
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_mobilevit"] = ["MobileViTFeatureExtractor"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_mobilevit"] = [
+ "MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "MobileViTForImageClassification",
+ "MobileViTForSemanticSegmentation",
+ "MobileViTModel",
+ "MobileViTPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_mobilevit import MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileViTConfig, MobileViTOnnxConfig
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_mobilevit import MobileViTFeatureExtractor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_mobilevit import (
+ MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ MobileViTForImageClassification,
+ MobileViTForSemanticSegmentation,
+ MobileViTModel,
+ MobileViTPreTrainedModel,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
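
For orientation, a minimal sketch (not part of the diff) of what the lazy-import surface above exposes once the optional torch dependency is installed:

```python
# Minimal sketch (not part of the diff): exercising the lazy module defined above.
# Assumes torch is installed so the modeling branch of the lazy import resolves.
from transformers.models.mobilevit import MobileViTConfig, MobileViTModel

config = MobileViTConfig()      # mobilevit-small style defaults
model = MobileViTModel(config)  # randomly initialized weights
print(type(model).__name__)     # "MobileViTModel"
```
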
diff --git a/src/transformers/models/mobilevit/configuration_mobilevit.py b/src/transformers/models/mobilevit/configuration_mobilevit.py
new file mode 100644
index 000000000000..87a8a009ddc3
--- /dev/null
+++ b/src/transformers/models/mobilevit/configuration_mobilevit.py
@@ -0,0 +1,185 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" MobileViT model configuration"""
+
+from collections import OrderedDict
+from typing import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "apple/mobilevit-small": "https://huggingface.co/apple/mobilevit-small/resolve/main/config.json",
+ "apple/mobilevit-x-small": "https://huggingface.co/apple/mobilevit-x-small/resolve/main/config.json",
+ "apple/mobilevit-xx-small": "https://huggingface.co/apple/mobilevit-xx-small/resolve/main/config.json",
+ "apple/deeplabv3-mobilevit-small": (
+ "https://huggingface.co/apple/deeplabv3-mobilevit-small/resolve/main/config.json"
+ ),
+ "apple/deeplabv3-mobilevit-x-small": (
+ "https://huggingface.co/apple/deeplabv3-mobilevit-x-small/resolve/main/config.json"
+ ),
+ "apple/deeplabv3-mobilevit-xx-small": (
+ "https://huggingface.co/apple/deeplabv3-mobilevit-xx-small/resolve/main/config.json"
+ ),
+ # See all MobileViT models at https://huggingface.co/models?filter=mobilevit
+}
+
+
+class MobileViTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MobileViTModel`]. It is used to instantiate a
+ MobileViT model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the MobileViT
+ [apple/mobilevit-small](https://huggingface.co/apple/mobilevit-small) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ image_size (`int`, *optional*, defaults to 256):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 2):
+ The size (resolution) of each patch.
+ hidden_sizes (`List[int]`, *optional*, defaults to `[144, 192, 240]`):
+ Dimensionality (hidden size) of the Transformer encoders at each stage.
+ neck_hidden_sizes (`List[int]`, *optional*, defaults to `[16, 32, 64, 96, 128, 160, 640]`):
+ The number of channels for the feature maps of the backbone.
+ num_attention_heads (`int`, *optional*, defaults to 4):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ mlp_ratio (`float`, *optional*, defaults to 2.0):
+ The ratio of the number of channels in the output of the MLP to the number of channels in the input.
+ expand_ratio (`float`, *optional*, defaults to 4.0):
+ Expansion factor for the MobileNetv2 layers.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the Transformer encoder and convolution layers.
+ conv_kernel_size (`int`, *optional*, defaults to 3):
+ The size of the convolutional kernel in the MobileViT layer.
+        output_stride (`int`, *optional*, defaults to 32):
+            The ratio of the spatial resolution of the output to the resolution of the input image.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the Transformer encoder.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for attached classifiers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+        aspp_out_channels (`int`, *optional*, defaults to 256):
+ Number of output channels used in the ASPP layer for semantic segmentation.
+ atrous_rates (`List[int]`, *optional*, defaults to `[6, 12, 18]`):
+ Dilation (atrous) factors used in the ASPP layer for semantic segmentation.
+ aspp_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the ASPP layer for semantic segmentation.
+ semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
+ The index that is ignored by the loss function of the semantic segmentation model.
+
+ Example:
+
+ ```python
+ >>> from transformers import MobileViTConfig, MobileViTModel
+
+ >>> # Initializing a mobilevit-small style configuration
+ >>> configuration = MobileViTConfig()
+
+ >>> # Initializing a model from the mobilevit-small style configuration
+ >>> model = MobileViTModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "mobilevit"
+
+ def __init__(
+ self,
+ num_channels=3,
+ image_size=256,
+ patch_size=2,
+ hidden_sizes=[144, 192, 240],
+ neck_hidden_sizes=[16, 32, 64, 96, 128, 160, 640],
+ num_attention_heads=4,
+ mlp_ratio=2.0,
+ expand_ratio=4.0,
+ hidden_act="silu",
+ conv_kernel_size=3,
+ output_stride=32,
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.0,
+ classifier_dropout_prob=0.1,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ qkv_bias=True,
+ aspp_out_channels=256,
+ atrous_rates=[6, 12, 18],
+ aspp_dropout_prob=0.1,
+ semantic_loss_ignore_index=255,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.hidden_sizes = hidden_sizes
+ self.neck_hidden_sizes = neck_hidden_sizes
+ self.num_attention_heads = num_attention_heads
+ self.mlp_ratio = mlp_ratio
+ self.expand_ratio = expand_ratio
+ self.hidden_act = hidden_act
+ self.conv_kernel_size = conv_kernel_size
+ self.output_stride = output_stride
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.classifier_dropout_prob = classifier_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.qkv_bias = qkv_bias
+
+ # decode head attributes for semantic segmentation
+ self.aspp_out_channels = aspp_out_channels
+ self.atrous_rates = atrous_rates
+ self.aspp_dropout_prob = aspp_dropout_prob
+ self.semantic_loss_ignore_index = semantic_loss_ignore_index
+
+
+class MobileViTOnnxConfig(OnnxConfig):
+
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict([("pixel_values", {0: "batch"})])
+
+ @property
+ def outputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "image-classification":
+ return OrderedDict([("logits", {0: "batch"})])
+ else:
+ return OrderedDict([("last_hidden_state", {0: "batch"}), ("pooler_output", {0: "batch"})])
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-4
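
As a quick illustration of the ONNX settings above, a small sketch (not part of the diff) that instantiates the export config and prints its declared inputs, outputs and validation tolerance; it assumes the `image-classification` task is already registered on the base `OnnxConfig`:

```python
# Minimal sketch (not part of the diff): inspecting the ONNX export settings above.
from transformers.models.mobilevit.configuration_mobilevit import MobileViTConfig, MobileViTOnnxConfig

config = MobileViTConfig()
onnx_config = MobileViTOnnxConfig(config, task="image-classification")

print(onnx_config.inputs)               # OrderedDict([('pixel_values', {0: 'batch'})])
print(onnx_config.outputs)              # OrderedDict([('logits', {0: 'batch'})])
print(onnx_config.atol_for_validation)  # 1e-4
```
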
diff --git a/src/transformers/models/mobilevit/convert_mlcvnets_to_pytorch.py b/src/transformers/models/mobilevit/convert_mlcvnets_to_pytorch.py
new file mode 100644
index 000000000000..7f3e07f7b540
--- /dev/null
+++ b/src/transformers/models/mobilevit/convert_mlcvnets_to_pytorch.py
@@ -0,0 +1,312 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert MobileViT checkpoints from the ml-cvnets library."""
+
+
+import argparse
+import json
+from pathlib import Path
+
+import torch
+from PIL import Image
+
+import requests
+from huggingface_hub import hf_hub_download
+from transformers import (
+ MobileViTConfig,
+ MobileViTFeatureExtractor,
+ MobileViTForImageClassification,
+ MobileViTForSemanticSegmentation,
+)
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+def get_mobilevit_config(mobilevit_name):
+ config = MobileViTConfig()
+
+ # size of the architecture
+ if "mobilevit_s" in mobilevit_name:
+ config.hidden_sizes = [144, 192, 240]
+ config.neck_hidden_sizes = [16, 32, 64, 96, 128, 160, 640]
+ elif "mobilevit_xs" in mobilevit_name:
+ config.hidden_sizes = [96, 120, 144]
+ config.neck_hidden_sizes = [16, 32, 48, 64, 80, 96, 384]
+ elif "mobilevit_xxs" in mobilevit_name:
+ config.hidden_sizes = [64, 80, 96]
+ config.neck_hidden_sizes = [16, 16, 24, 48, 64, 80, 320]
+ config.hidden_dropout_prob = 0.05
+ config.expand_ratio = 2.0
+
+ if mobilevit_name.startswith("deeplabv3_"):
+ config.image_size = 512
+ config.output_stride = 16
+ config.num_labels = 21
+ filename = "pascal-voc-id2label.json"
+ else:
+ config.num_labels = 1000
+ filename = "imagenet-1k-id2label.json"
+
+ repo_id = "datasets/huggingface/label-files"
+ id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+ id2label = {int(k): v for k, v in id2label.items()}
+ config.id2label = id2label
+ config.label2id = {v: k for k, v in id2label.items()}
+
+ return config
+
+
+def rename_key(name, base_model=False):
+ for i in range(1, 6):
+ if f"layer_{i}." in name:
+ name = name.replace(f"layer_{i}.", f"encoder.layer.{i - 1}.")
+
+ if "conv_1." in name:
+ name = name.replace("conv_1.", "conv_stem.")
+ if ".block." in name:
+ name = name.replace(".block.", ".")
+
+ if "exp_1x1" in name:
+ name = name.replace("exp_1x1", "expand_1x1")
+ if "red_1x1" in name:
+ name = name.replace("red_1x1", "reduce_1x1")
+ if ".local_rep.conv_3x3." in name:
+ name = name.replace(".local_rep.conv_3x3.", ".conv_kxk.")
+ if ".local_rep.conv_1x1." in name:
+ name = name.replace(".local_rep.conv_1x1.", ".conv_1x1.")
+ if ".norm." in name:
+ name = name.replace(".norm.", ".normalization.")
+ if ".conv." in name:
+ name = name.replace(".conv.", ".convolution.")
+ if ".conv_proj." in name:
+ name = name.replace(".conv_proj.", ".conv_projection.")
+
+ for i in range(0, 2):
+ for j in range(0, 4):
+ if f".{i}.{j}." in name:
+ name = name.replace(f".{i}.{j}.", f".{i}.layer.{j}.")
+
+ for i in range(2, 6):
+ for j in range(0, 4):
+ if f".{i}.{j}." in name:
+ name = name.replace(f".{i}.{j}.", f".{i}.")
+ if "expand_1x1" in name:
+ name = name.replace("expand_1x1", "downsampling_layer.expand_1x1")
+ if "conv_3x3" in name:
+ name = name.replace("conv_3x3", "downsampling_layer.conv_3x3")
+ if "reduce_1x1" in name:
+ name = name.replace("reduce_1x1", "downsampling_layer.reduce_1x1")
+
+ for i in range(2, 5):
+ if f".global_rep.{i}.weight" in name:
+ name = name.replace(f".global_rep.{i}.weight", ".layernorm.weight")
+ if f".global_rep.{i}.bias" in name:
+ name = name.replace(f".global_rep.{i}.bias", ".layernorm.bias")
+
+ if ".global_rep." in name:
+ name = name.replace(".global_rep.", ".transformer.")
+ if ".pre_norm_mha.0." in name:
+ name = name.replace(".pre_norm_mha.0.", ".layernorm_before.")
+ if ".pre_norm_mha.1.out_proj." in name:
+ name = name.replace(".pre_norm_mha.1.out_proj.", ".attention.output.dense.")
+ if ".pre_norm_ffn.0." in name:
+ name = name.replace(".pre_norm_ffn.0.", ".layernorm_after.")
+ if ".pre_norm_ffn.1." in name:
+ name = name.replace(".pre_norm_ffn.1.", ".intermediate.dense.")
+ if ".pre_norm_ffn.4." in name:
+ name = name.replace(".pre_norm_ffn.4.", ".output.dense.")
+ if ".transformer." in name:
+ name = name.replace(".transformer.", ".transformer.layer.")
+
+ if ".aspp_layer." in name:
+ name = name.replace(".aspp_layer.", ".")
+ if ".aspp_pool." in name:
+ name = name.replace(".aspp_pool.", ".")
+ if "seg_head." in name:
+ name = name.replace("seg_head.", "segmentation_head.")
+ if "segmentation_head.classifier.classifier." in name:
+ name = name.replace("segmentation_head.classifier.classifier.", "segmentation_head.classifier.")
+
+ if "classifier.fc." in name:
+ name = name.replace("classifier.fc.", "classifier.")
+ elif (not base_model) and ("segmentation_head." not in name):
+ name = "mobilevit." + name
+
+ return name
+
+
+def convert_state_dict(orig_state_dict, model, base_model=False):
+ if base_model:
+ model_prefix = ""
+ else:
+ model_prefix = "mobilevit."
+
+ for key in orig_state_dict.copy().keys():
+ val = orig_state_dict.pop(key)
+
+ if key[:8] == "encoder.":
+ key = key[8:]
+
+ if "qkv" in key:
+ key_split = key.split(".")
+ layer_num = int(key_split[0][6:]) - 1
+ transformer_num = int(key_split[3])
+ layer = model.get_submodule(f"{model_prefix}encoder.layer.{layer_num}")
+ dim = layer.transformer.layer[transformer_num].attention.attention.all_head_size
+ prefix = (
+ f"{model_prefix}encoder.layer.{layer_num}.transformer.layer.{transformer_num}.attention.attention."
+ )
+ if "weight" in key:
+ orig_state_dict[prefix + "query.weight"] = val[:dim, :]
+ orig_state_dict[prefix + "key.weight"] = val[dim : dim * 2, :]
+ orig_state_dict[prefix + "value.weight"] = val[-dim:, :]
+ else:
+ orig_state_dict[prefix + "query.bias"] = val[:dim]
+ orig_state_dict[prefix + "key.bias"] = val[dim : dim * 2]
+ orig_state_dict[prefix + "value.bias"] = val[-dim:]
+ else:
+ orig_state_dict[rename_key(key, base_model)] = val
+
+ return orig_state_dict
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ im = Image.open(requests.get(url, stream=True).raw)
+ return im
+
+
+@torch.no_grad()
+def convert_movilevit_checkpoint(mobilevit_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub=False):
+ """
+ Copy/paste/tweak model's weights to our MobileViT structure.
+ """
+ config = get_mobilevit_config(mobilevit_name)
+
+ # load original state_dict
+ state_dict = torch.load(checkpoint_path, map_location="cpu")
+
+    # load 🤗 model
+ if mobilevit_name.startswith("deeplabv3_"):
+ model = MobileViTForSemanticSegmentation(config).eval()
+ else:
+ model = MobileViTForImageClassification(config).eval()
+
+ new_state_dict = convert_state_dict(state_dict, model)
+ model.load_state_dict(new_state_dict)
+
+ # Check outputs on an image, prepared by MobileViTFeatureExtractor
+ feature_extractor = MobileViTFeatureExtractor(crop_size=config.image_size, size=config.image_size + 32)
+ encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
+ outputs = model(**encoding)
+ logits = outputs.logits
+
+ if mobilevit_name.startswith("deeplabv3_"):
+ assert logits.shape == (1, 21, 32, 32)
+
+ if mobilevit_name == "deeplabv3_mobilevit_s":
+ expected_logits = torch.tensor(
+ [
+ [[6.2065, 6.1292, 6.2070], [6.1079, 6.1254, 6.1747], [6.0042, 6.1071, 6.1034]],
+ [[-6.9253, -6.8653, -7.0398], [-7.3218, -7.3983, -7.3670], [-7.1961, -7.2482, -7.1569]],
+ [[-4.4723, -4.4348, -4.3769], [-5.3629, -5.4632, -5.4598], [-5.1587, -5.3402, -5.5059]],
+ ]
+ )
+ elif mobilevit_name == "deeplabv3_mobilevit_xs":
+ expected_logits = torch.tensor(
+ [
+ [[5.4449, 5.5733, 5.6314], [5.1815, 5.3930, 5.5963], [5.1656, 5.4333, 5.4853]],
+ [[-9.4423, -9.7766, -9.6714], [-9.1581, -9.5720, -9.5519], [-9.1006, -9.6458, -9.5703]],
+ [[-7.7721, -7.3716, -7.1583], [-8.4599, -8.0624, -7.7944], [-8.4172, -7.8366, -7.5025]],
+ ]
+ )
+ elif mobilevit_name == "deeplabv3_mobilevit_xxs":
+ expected_logits = torch.tensor(
+ [
+ [[6.9811, 6.9743, 7.3123], [7.1777, 7.1931, 7.3938], [7.5633, 7.8050, 7.8901]],
+ [[-10.5536, -10.2332, -10.2924], [-10.2336, -9.8624, -9.5964], [-10.8840, -10.8158, -10.6659]],
+ [[-3.4938, -3.0631, -2.8620], [-3.4205, -2.8135, -2.6875], [-3.4179, -2.7945, -2.8750]],
+ ]
+ )
+ else:
+ raise ValueError(f"Unknown mobilevit_name: {mobilevit_name}")
+
+ assert torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-4)
+ else:
+ assert logits.shape == (1, 1000)
+
+ if mobilevit_name == "mobilevit_s":
+ expected_logits = torch.tensor([-0.9866, 0.2392, -1.1241])
+ elif mobilevit_name == "mobilevit_xs":
+ expected_logits = torch.tensor([-2.4761, -0.9399, -1.9587])
+ elif mobilevit_name == "mobilevit_xxs":
+ expected_logits = torch.tensor([-1.9364, -1.2327, -0.4653])
+ else:
+ raise ValueError(f"Unknown mobilevit_name: {mobilevit_name}")
+
+ assert torch.allclose(logits[0, :3], expected_logits, atol=1e-4)
+
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
+ print(f"Saving model {mobilevit_name} to {pytorch_dump_folder_path}")
+ model.save_pretrained(pytorch_dump_folder_path)
+ print(f"Saving feature extractor to {pytorch_dump_folder_path}")
+ feature_extractor.save_pretrained(pytorch_dump_folder_path)
+
+ if push_to_hub:
+ model_mapping = {
+ "mobilevit_s": "mobilevit-small",
+ "mobilevit_xs": "mobilevit-x-small",
+ "mobilevit_xxs": "mobilevit-xx-small",
+ "deeplabv3_mobilevit_s": "deeplabv3-mobilevit-small",
+ "deeplabv3_mobilevit_xs": "deeplabv3-mobilevit-x-small",
+ "deeplabv3_mobilevit_xxs": "deeplabv3-mobilevit-xx-small",
+ }
+
+ print("Pushing to the hub...")
+ model_name = model_mapping[mobilevit_name]
+ feature_extractor.push_to_hub(model_name, organization="apple")
+ model.push_to_hub(model_name, organization="apple")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--mobilevit_name",
+ default="mobilevit_s",
+ type=str,
+ help=(
+ "Name of the MobileViT model you'd like to convert. Should be one of 'mobilevit_s', 'mobilevit_xs',"
+ " 'mobilevit_xxs', 'deeplabv3_mobilevit_s', 'deeplabv3_mobilevit_xs', 'deeplabv3_mobilevit_xxs'."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoint_path", required=True, type=str, help="Path to the original state dict (.pt file)."
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", required=True, type=str, help="Path to the output PyTorch model directory."
+ )
+ parser.add_argument(
+        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
+ )
+
+ args = parser.parse_args()
+ convert_movilevit_checkpoint(
+ args.mobilevit_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub
+ )
diff --git a/src/transformers/models/mobilevit/feature_extraction_mobilevit.py b/src/transformers/models/mobilevit/feature_extraction_mobilevit.py
new file mode 100644
index 000000000000..51e022b809c9
--- /dev/null
+++ b/src/transformers/models/mobilevit/feature_extraction_mobilevit.py
@@ -0,0 +1,153 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for MobileViT."""
+
+from typing import Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
+from ...utils import TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MobileViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
+ r"""
+ Constructs a MobileViT feature extractor.
+
+ This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
+ should refer to this superclass for more information regarding those methods.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the input to a certain `size`.
+ size (`int` or `Tuple(int)`, *optional*, defaults to 288):
+ Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
+ integer is provided, then the input will be resized to match the shorter side. Only has an effect if
+ `do_resize` is set to `True`.
+ resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
+ An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+ `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+ if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
+ image is padded with 0's and then center cropped.
+ crop_size (`int`, *optional*, defaults to 256):
+ Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
+ do_flip_channel_order (`bool`, *optional*, defaults to `True`):
+ Whether to flip the color channels from RGB to BGR.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize=True,
+ size=288,
+ resample=Image.BILINEAR,
+ do_center_crop=True,
+ crop_size=256,
+ do_flip_channel_order=True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_flip_channel_order = do_flip_channel_order
+
+ def __call__(
+ self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several image(s).
+
+ NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
+ PIL images.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
+ number of channels, H and W are image height and width.
+
+ return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
+ width).
+ """
+ # Input type checking for clearer error
+ valid_images = False
+
+ # Check that images has a valid type
+ if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
+ valid_images = True
+ elif isinstance(images, (list, tuple)):
+ if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
+ valid_images = True
+
+ if not valid_images:
+ raise ValueError(
+                "Images must be of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
+ "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
+ )
+
+ is_batched = bool(
+ isinstance(images, (list, tuple))
+ and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
+ )
+
+ if not is_batched:
+ images = [images]
+
+ # transformations (resizing + normalization)
+ if self.do_resize and self.size is not None:
+ images = [
+ self.resize(image=image, size=self.size, resample=self.resample, default_to_square=False)
+ for image in images
+ ]
+ if self.do_center_crop and self.crop_size is not None:
+ images = [self.center_crop(image, self.crop_size) for image in images]
+
+ images = [self.to_numpy_array(image) for image in images]
+
+ # the pretrained checkpoints assume images are BGR, not RGB
+ if self.do_flip_channel_order:
+ images = [self.flip_channel_order(image) for image in images]
+
+ # return as BatchFeature
+ data = {"pixel_values": images}
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ return encoded_inputs
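
To see the preprocessing pipeline above end to end, a small sketch (not part of the diff) that feeds a COCO test image through the feature extractor; it assumes the converted `apple/mobilevit-small` checkpoint and its preprocessing config are available on the hub and that torch is installed:

```python
# Minimal sketch (not part of the diff): running the feature extractor above on a
# COCO image. The checkpoint name comes from the archive map earlier in this PR.
import requests
from PIL import Image

from transformers.models.mobilevit import MobileViTFeatureExtractor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/mobilevit-small")
inputs = feature_extractor(images=image, return_tensors="pt")

# shorter side resized to 288, center-cropped to 256, channels flipped RGB -> BGR
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 256, 256])
```
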
diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py
new file mode 100755
index 000000000000..fadfc4de3052
--- /dev/null
+++ b/src/transformers/models/mobilevit/modeling_mobilevit.py
@@ -0,0 +1,1087 @@
+# coding=utf-8
+# Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE
+""" PyTorch MobileViT model."""
+
+
+import math
+from typing import Dict, Optional, Set, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+ BaseModelOutputWithNoAttention,
+ BaseModelOutputWithPoolingAndNoAttention,
+ ImageClassifierOutputWithNoAttention,
+ SemanticSegmenterOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_mobilevit import MobileViTConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+# General docstring
+_CONFIG_FOR_DOC = "MobileViTConfig"
+_FEAT_EXTRACTOR_FOR_DOC = "MobileViTFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "apple/mobilevit-small"
+_EXPECTED_OUTPUT_SHAPE = [1, 640, 8, 8]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "apple/mobilevit-small"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
+MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "apple/mobilevit-small",
+ "apple/mobilevit-x-small",
+ "apple/mobilevit-xx-small",
+ "apple/deeplabv3-mobilevit-small",
+ "apple/deeplabv3-mobilevit-x-small",
+ "apple/deeplabv3-mobilevit-xx-small",
+ # See all MobileViT models at https://huggingface.co/models?filter=mobilevit
+]
+
+
+def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int:
+ """
+ Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the
+ original TensorFlow repo. It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ """
+ if min_value is None:
+ min_value = divisor
+ new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_value < 0.9 * value:
+ new_value += divisor
+ return int(new_value)
+
+
+class MobileViTConvLayer(nn.Module):
+ def __init__(
+ self,
+ config: MobileViTConfig,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ groups: int = 1,
+ bias: bool = False,
+ dilation: int = 1,
+ use_normalization: bool = True,
+ use_activation: Union[bool, str] = True,
+ ) -> None:
+ super().__init__()
+ padding = int((kernel_size - 1) / 2) * dilation
+
+ if in_channels % groups != 0:
+ raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
+ if out_channels % groups != 0:
+ raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")
+
+ self.convolution = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ padding_mode="zeros",
+ )
+
+ if use_normalization:
+ self.normalization = nn.BatchNorm2d(
+ num_features=out_channels,
+ eps=1e-5,
+ momentum=0.1,
+ affine=True,
+ track_running_stats=True,
+ )
+ else:
+ self.normalization = None
+
+ if use_activation:
+ if isinstance(use_activation, str):
+ self.activation = ACT2FN[use_activation]
+ elif isinstance(config.hidden_act, str):
+ self.activation = ACT2FN[config.hidden_act]
+ else:
+ self.activation = config.hidden_act
+ else:
+ self.activation = None
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ features = self.convolution(features)
+ if self.normalization is not None:
+ features = self.normalization(features)
+ if self.activation is not None:
+ features = self.activation(features)
+ return features
+
+
+class MobileViTInvertedResidual(nn.Module):
+ """
+ Inverted residual block (MobileNetv2): https://arxiv.org/abs/1801.04381
+ """
+
+ def __init__(
+ self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1
+ ) -> None:
+ super().__init__()
+ expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)
+
+ if stride not in [1, 2]:
+ raise ValueError(f"Invalid stride {stride}.")
+
+ self.use_residual = (stride == 1) and (in_channels == out_channels)
+
+ self.expand_1x1 = MobileViTConvLayer(
+ config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1
+ )
+
+ self.conv_3x3 = MobileViTConvLayer(
+ config,
+ in_channels=expanded_channels,
+ out_channels=expanded_channels,
+ kernel_size=3,
+ stride=stride,
+ groups=expanded_channels,
+ dilation=dilation,
+ )
+
+ self.reduce_1x1 = MobileViTConvLayer(
+ config,
+ in_channels=expanded_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ use_activation=False,
+ )
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ residual = features
+
+ features = self.expand_1x1(features)
+ features = self.conv_3x3(features)
+ features = self.reduce_1x1(features)
+
+ return residual + features if self.use_residual else features
+
+
+class MobileViTMobileNetLayer(nn.Module):
+ def __init__(
+ self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1
+ ) -> None:
+ super().__init__()
+
+ self.layer = nn.ModuleList()
+ for i in range(num_stages):
+ layer = MobileViTInvertedResidual(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride if i == 0 else 1,
+ )
+ self.layer.append(layer)
+ in_channels = out_channels
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ for layer_module in self.layer:
+ features = layer_module(features)
+ return features
+
+
+class MobileViTSelfAttention(nn.Module):
+ def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
+ super().__init__()
+
+ if hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+                f"The hidden size {hidden_size} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+ return context_layer
+
+
+class MobileViTSelfOutput(nn.Module):
+ def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
+ super().__init__()
+ self.dense = nn.Linear(hidden_size, hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class MobileViTAttention(nn.Module):
+ def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
+ super().__init__()
+ self.attention = MobileViTSelfAttention(config, hidden_size)
+ self.output = MobileViTSelfOutput(config, hidden_size)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: Set[int]) -> None:
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ self_outputs = self.attention(hidden_states)
+ attention_output = self.output(self_outputs)
+ return attention_output
+
+
+class MobileViTIntermediate(nn.Module):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
+ super().__init__()
+ self.dense = nn.Linear(hidden_size, intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class MobileViTOutput(nn.Module):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
+ super().__init__()
+ self.dense = nn.Linear(intermediate_size, hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states + input_tensor
+ return hidden_states
+
+
+class MobileViTTransformerLayer(nn.Module):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
+ super().__init__()
+ self.attention = MobileViTAttention(config, hidden_size)
+ self.intermediate = MobileViTIntermediate(config, hidden_size, intermediate_size)
+ self.output = MobileViTOutput(config, hidden_size, intermediate_size)
+ self.layernorm_before = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ attention_output = self.attention(self.layernorm_before(hidden_states))
+ hidden_states = attention_output + hidden_states
+
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+ layer_output = self.output(layer_output, hidden_states)
+ return layer_output
+
+
+class MobileViTTransformer(nn.Module):
+ def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int) -> None:
+ super().__init__()
+
+ self.layer = nn.ModuleList()
+ for _ in range(num_stages):
+ transformer_layer = MobileViTTransformerLayer(
+ config,
+ hidden_size=hidden_size,
+ intermediate_size=int(hidden_size * config.mlp_ratio),
+ )
+ self.layer.append(transformer_layer)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for layer_module in self.layer:
+ hidden_states = layer_module(hidden_states)
+ return hidden_states
+
+
+class MobileViTLayer(nn.Module):
+ """
+ MobileViT block: https://arxiv.org/abs/2110.02178
+ """
+
+ def __init__(
+ self,
+ config: MobileViTConfig,
+ in_channels: int,
+ out_channels: int,
+ stride: int,
+ hidden_size: int,
+ num_stages: int,
+ dilation: int = 1,
+ ) -> None:
+ super().__init__()
+ self.patch_width = config.patch_size
+ self.patch_height = config.patch_size
+
+ if stride == 2:
+ self.downsampling_layer = MobileViTInvertedResidual(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride if dilation == 1 else 1,
+ dilation=dilation // 2 if dilation > 1 else 1,
+ )
+ in_channels = out_channels
+ else:
+ self.downsampling_layer = None
+
+ self.conv_kxk = MobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=config.conv_kernel_size,
+ )
+
+ self.conv_1x1 = MobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=hidden_size,
+ kernel_size=1,
+ use_normalization=False,
+ use_activation=False,
+ )
+
+ self.transformer = MobileViTTransformer(
+ config,
+ hidden_size=hidden_size,
+ num_stages=num_stages,
+ )
+
+ self.layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+
+ self.conv_projection = MobileViTConvLayer(
+ config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1
+ )
+
+ self.fusion = MobileViTConvLayer(
+ config, in_channels=2 * in_channels, out_channels=in_channels, kernel_size=config.conv_kernel_size
+ )
+
+ def unfolding(self, features: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
+ patch_width, patch_height = self.patch_width, self.patch_height
+ patch_area = int(patch_width * patch_height)
+
+ batch_size, channels, orig_height, orig_width = features.shape
+
+ new_height = int(math.ceil(orig_height / patch_height) * patch_height)
+ new_width = int(math.ceil(orig_width / patch_width) * patch_width)
+
+ interpolate = False
+ if new_width != orig_width or new_height != orig_height:
+ # Note: Padding can be done, but then it needs to be handled in attention function.
+ features = nn.functional.interpolate(
+ features, size=(new_height, new_width), mode="bilinear", align_corners=False
+ )
+ interpolate = True
+
+ # number of patches along width and height
+ num_patch_width = new_width // patch_width
+ num_patch_height = new_height // patch_height
+ num_patches = num_patch_height * num_patch_width
+
+ # convert from shape (batch_size, channels, orig_height, orig_width)
+ # to the shape (batch_size * patch_area, num_patches, channels)
+ patches = features.reshape(
+ batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width
+ )
+ patches = patches.transpose(1, 2)
+ patches = patches.reshape(batch_size, channels, num_patches, patch_area)
+ patches = patches.transpose(1, 3)
+ patches = patches.reshape(batch_size * patch_area, num_patches, -1)
+
+ info_dict = {
+ "orig_size": (orig_height, orig_width),
+ "batch_size": batch_size,
+ "channels": channels,
+ "interpolate": interpolate,
+ "num_patches": num_patches,
+ "num_patches_width": num_patch_width,
+ "num_patches_height": num_patch_height,
+ }
+ return patches, info_dict
+
+ def folding(self, patches: torch.Tensor, info_dict: Dict) -> torch.Tensor:
+ patch_width, patch_height = self.patch_width, self.patch_height
+ patch_area = int(patch_width * patch_height)
+
+ batch_size = info_dict["batch_size"]
+ channels = info_dict["channels"]
+ num_patches = info_dict["num_patches"]
+ num_patch_height = info_dict["num_patches_height"]
+ num_patch_width = info_dict["num_patches_width"]
+
+ # convert from shape (batch_size * patch_area, num_patches, channels)
+ # back to shape (batch_size, channels, orig_height, orig_width)
+ features = patches.contiguous().view(batch_size, patch_area, num_patches, -1)
+ features = features.transpose(1, 3)
+ features = features.reshape(
+ batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width
+ )
+ features = features.transpose(1, 2)
+ features = features.reshape(
+ batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width
+ )
+
+ if info_dict["interpolate"]:
+ features = nn.functional.interpolate(
+ features, size=info_dict["orig_size"], mode="bilinear", align_corners=False
+ )
+
+ return features
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ # reduce spatial dimensions if needed
+ if self.downsampling_layer:
+ features = self.downsampling_layer(features)
+
+ residual = features
+
+ # local representation
+ features = self.conv_kxk(features)
+ features = self.conv_1x1(features)
+
+ # convert feature map to patches
+ patches, info_dict = self.unfolding(features)
+
+ # learn global representations
+ patches = self.transformer(patches)
+ patches = self.layernorm(patches)
+
+ # convert patches back to feature maps
+ features = self.folding(patches, info_dict)
+
+ features = self.conv_projection(features)
+ features = self.fusion(torch.cat((residual, features), dim=1))
+ return features
+
+
+class MobileViTEncoder(nn.Module):
+ def __init__(self, config: MobileViTConfig) -> None:
+ super().__init__()
+ self.config = config
+
+ self.layer = nn.ModuleList()
+ self.gradient_checkpointing = False
+
+ # segmentation architectures like DeepLab and PSPNet modify the strides
+ # of the classification backbones
+ dilate_layer_4 = dilate_layer_5 = False
+ if config.output_stride == 8:
+ dilate_layer_4 = True
+ dilate_layer_5 = True
+ elif config.output_stride == 16:
+ dilate_layer_5 = True
+
+ dilation = 1
+
+ layer_1 = MobileViTMobileNetLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[0],
+ out_channels=config.neck_hidden_sizes[1],
+ stride=1,
+ num_stages=1,
+ )
+ self.layer.append(layer_1)
+
+ layer_2 = MobileViTMobileNetLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[1],
+ out_channels=config.neck_hidden_sizes[2],
+ stride=2,
+ num_stages=3,
+ )
+ self.layer.append(layer_2)
+
+ layer_3 = MobileViTLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[2],
+ out_channels=config.neck_hidden_sizes[3],
+ stride=2,
+ hidden_size=config.hidden_sizes[0],
+ num_stages=2,
+ )
+ self.layer.append(layer_3)
+
+ if dilate_layer_4:
+ dilation *= 2
+
+ layer_4 = MobileViTLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[3],
+ out_channels=config.neck_hidden_sizes[4],
+ stride=2,
+ hidden_size=config.hidden_sizes[1],
+ num_stages=4,
+ dilation=dilation,
+ )
+ self.layer.append(layer_4)
+
+ if dilate_layer_5:
+ dilation *= 2
+
+ layer_5 = MobileViTLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[4],
+ out_channels=config.neck_hidden_sizes[5],
+ stride=2,
+ hidden_size=config.hidden_sizes[2],
+ num_stages=3,
+ dilation=dilation,
+ )
+ self.layer.append(layer_5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutputWithNoAttention]:
+ all_hidden_states = () if output_hidden_states else None
+
+ for i, layer_module in enumerate(self.layer):
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ )
+ else:
+ hidden_states = layer_module(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+
+ return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
+
+
+class MobileViTPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = MobileViTConfig
+ base_model_prefix = "mobilevit"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, MobileViTEncoder):
+ module.gradient_checkpointing = value
+
+
+MOBILEVIT_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`MobileViTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+MOBILEVIT_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`MobileViTFeatureExtractor`]. See
+ [`MobileViTFeatureExtractor.__call__`] for details.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare MobileViT model outputting raw hidden-states without any specific head on top.",
+ MOBILEVIT_START_DOCSTRING,
+)
+class MobileViTModel(MobileViTPreTrainedModel):
+ def __init__(self, config: MobileViTConfig, expand_output: bool = True):
+ super().__init__(config)
+ self.config = config
+ self.expand_output = expand_output
+
+ self.conv_stem = MobileViTConvLayer(
+ config,
+ in_channels=config.num_channels,
+ out_channels=config.neck_hidden_sizes[0],
+ kernel_size=3,
+ stride=2,
+ )
+
+ self.encoder = MobileViTEncoder(config)
+
+ if self.expand_output:
+ self.conv_1x1_exp = MobileViTConvLayer(
+ config,
+ in_channels=config.neck_hidden_sizes[5],
+ out_channels=config.neck_hidden_sizes[6],
+ kernel_size=1,
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _prune_heads(self, heads_to_prune):
+ """Prunes heads of the model.
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
+ """
+ for layer_index, heads in heads_to_prune.items():
+ mobilevit_layer = self.encoder.layer[layer_index]
+ if isinstance(mobilevit_layer, MobileViTLayer):
+ for transformer_layer in mobilevit_layer.transformer.layer:
+ transformer_layer.attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPoolingAndNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ embedding_output = self.conv_stem(pixel_values)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.expand_output:
+ last_hidden_state = self.conv_1x1_exp(encoder_outputs[0])
+
+ # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
+ pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False)
+ else:
+ last_hidden_state = encoder_outputs[0]
+ pooled_output = None
+
+ if not return_dict:
+ output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)
+ return output + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ )
+
+
+@add_start_docstrings(
+ """
+ MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """,
+ MOBILEVIT_START_DOCSTRING,
+)
+class MobileViTForImageClassification(MobileViTPreTrainedModel):
+ def __init__(self, config: MobileViTConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.mobilevit = MobileViTModel(config)
+
+ # Classifier head
+ self.dropout = nn.Dropout(config.classifier_dropout_prob, inplace=True)
+ self.classifier = (
+ nn.Linear(config.neck_hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=ImageClassifierOutputWithNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.mobilevit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+ logits = self.classifier(self.dropout(pooled_output))
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutputWithNoAttention(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ )
+
+
+class MobileViTASPPPooling(nn.Module):
+ def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int) -> None:
+ super().__init__()
+
+ self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)
+
+ self.conv_1x1 = MobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ use_normalization=True,
+ use_activation="relu",
+ )
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ spatial_size = features.shape[-2:]
+ features = self.global_pool(features)
+ features = self.conv_1x1(features)
+ features = nn.functional.interpolate(features, size=spatial_size, mode="bilinear", align_corners=False)
+ return features
+
+
+class MobileViTASPP(nn.Module):
+ """
+ ASPP module defined in DeepLab papers: https://arxiv.org/abs/1606.00915, https://arxiv.org/abs/1706.05587
+ """
+
+ def __init__(self, config: MobileViTConfig) -> None:
+ super().__init__()
+
+ in_channels = config.neck_hidden_sizes[-2]
+ out_channels = config.aspp_out_channels
+
+ if len(config.atrous_rates) != 3:
+ raise ValueError("Expected 3 values for atrous_rates")
+
+ self.convs = nn.ModuleList()
+
+ in_projection = MobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ use_activation="relu",
+ )
+ self.convs.append(in_projection)
+
+ self.convs.extend(
+ [
+ MobileViTConvLayer(
+ config,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ dilation=rate,
+ use_activation="relu",
+ )
+ for rate in config.atrous_rates
+ ]
+ )
+
+ pool_layer = MobileViTASPPPooling(config, in_channels, out_channels)
+ self.convs.append(pool_layer)
+
+ self.project = MobileViTConvLayer(
+ config, in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu"
+ )
+
+ self.dropout = nn.Dropout(p=config.aspp_dropout_prob)
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ pyramid = []
+ for conv in self.convs:
+ pyramid.append(conv(features))
+ pyramid = torch.cat(pyramid, dim=1)
+
+ pooled_features = self.project(pyramid)
+ pooled_features = self.dropout(pooled_features)
+ return pooled_features
+
+
+class MobileViTDeepLabV3(nn.Module):
+ """
+ DeepLabv3 architecture: https://arxiv.org/abs/1706.05587
+ """
+
+ def __init__(self, config: MobileViTConfig) -> None:
+ super().__init__()
+ self.aspp = MobileViTASPP(config)
+
+ self.dropout = nn.Dropout2d(config.classifier_dropout_prob)
+
+ self.classifier = MobileViTConvLayer(
+ config,
+ in_channels=config.aspp_out_channels,
+ out_channels=config.num_labels,
+ kernel_size=1,
+ use_normalization=False,
+ use_activation=False,
+ bias=True,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ features = self.aspp(hidden_states[-1])
+ features = self.dropout(features)
+ features = self.classifier(features)
+ return features
+
+
+@add_start_docstrings(
+ """
+ MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC.
+ """,
+ MOBILEVIT_START_DOCSTRING,
+)
+class MobileViTForSemanticSegmentation(MobileViTPreTrainedModel):
+ def __init__(self, config: MobileViTConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.mobilevit = MobileViTModel(config, expand_output=False)
+ self.segmentation_head = MobileViTDeepLabV3(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, SemanticSegmenterOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-small")
+ >>> model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")
+
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> # logits are of shape (batch_size, num_labels, height, width)
+ >>> logits = outputs.logits
+ ```"""
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.mobilevit(
+ pixel_values,
+ output_hidden_states=True, # we need the intermediate hidden states
+ return_dict=return_dict,
+ )
+
+ encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+ logits = self.segmentation_head(encoder_hidden_states)
+
+ loss = None
+ if labels is not None:
+ if self.config.num_labels == 1:
+ raise ValueError("The number of labels should be greater than one")
+ else:
+ # upsample logits to the images' original size
+ upsampled_logits = nn.functional.interpolate(
+ logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+ )
+ loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
+ loss = loss_fct(upsampled_logits, labels)
+
+ if not return_dict:
+ if output_hidden_states:
+ output = (logits,) + outputs[1:]
+ else:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SemanticSegmenterOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=None,
+ )
diff --git a/src/transformers/models/mpnet/__init__.py b/src/transformers/models/mpnet/__init__.py
index 54c2c7b8419a..5b3bc0dbd375 100644
--- a/src/transformers/models/mpnet/__init__.py
+++ b/src/transformers/models/mpnet/__init__.py
@@ -18,7 +18,14 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -26,10 +33,20 @@
"tokenization_mpnet": ["MPNetTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_mpnet_fast"] = ["MPNetTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_mpnet"] = [
"MPNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"MPNetForMaskedLM",
@@ -42,7 +59,12 @@
"MPNetPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_mpnet"] = [
"TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFMPNetEmbeddings",
@@ -61,10 +83,20 @@
from .configuration_mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig
from .tokenization_mpnet import MPNetTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_mpnet_fast import MPNetTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_mpnet import (
MPNET_PRETRAINED_MODEL_ARCHIVE_LIST,
MPNetForMaskedLM,
@@ -77,7 +109,12 @@
MPNetPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_mpnet import (
TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMPNetEmbeddings,
diff --git a/src/transformers/models/mpnet/tokenization_mpnet.py b/src/transformers/models/mpnet/tokenization_mpnet.py
index f092e6a311a9..713a528d557a 100644
--- a/src/transformers/models/mpnet/tokenization_mpnet.py
+++ b/src/transformers/models/mpnet/tokenization_mpnet.py
@@ -175,8 +175,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
diff --git a/src/transformers/models/mpnet/tokenization_mpnet_fast.py b/src/transformers/models/mpnet/tokenization_mpnet_fast.py
index c913f85682cc..f2fe4fe4fe8f 100644
--- a/src/transformers/models/mpnet/tokenization_mpnet_fast.py
+++ b/src/transformers/models/mpnet/tokenization_mpnet_fast.py
@@ -163,8 +163,9 @@ def mask_token(self) -> str:
MPNet tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily
         comprise the space before the *<mask>*.
"""
- if self._mask_token is None and self.verbose:
- logger.error("Using mask_token, but it is not set yet.")
+ if self._mask_token is None:
+ if self.verbose:
+ logger.error("Using mask_token, but it is not set yet.")
return None
return str(self._mask_token)
diff --git a/src/transformers/models/mt5/__init__.py b/src/transformers/models/mt5/__init__.py
index dd576cb0b25c..f6e717bd875b 100644
--- a/src/transformers/models/mt5/__init__.py
+++ b/src/transformers/models/mt5/__init__.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
@@ -42,31 +43,59 @@
MT5TokenizerFast = T5TokenizerFast
-_import_structure = {
- "configuration_mt5": ["MT5Config"],
-}
+_import_structure = {"configuration_mt5": ["MT5Config", "MT5OnnxConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_mt5"] = ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model"]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_mt5"] = ["TFMT5EncoderModel", "TFMT5ForConditionalGeneration", "TFMT5Model"]
-if is_flax_available():
- _import_structure["modeling_flax_mt5"] = ["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_mt5"] = ["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]
if TYPE_CHECKING:
- from .configuration_mt5 import MT5Config
-
- if is_torch_available():
+ from .configuration_mt5 import MT5Config, MT5OnnxConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_mt5 import TFMT5EncoderModel, TFMT5ForConditionalGeneration, TFMT5Model
- if is_flax_available():
- from .modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model
else:
import sys
diff --git a/src/transformers/models/mt5/configuration_mt5.py b/src/transformers/models/mt5/configuration_mt5.py
index d6a343f77dbc..3e72831ad25f 100644
--- a/src/transformers/models/mt5/configuration_mt5.py
+++ b/src/transformers/models/mt5/configuration_mt5.py
@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" mT5 model configuration"""
+from typing import Mapping
from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxSeq2SeqConfigWithPast
from ...utils import logging
@@ -117,6 +119,21 @@ def __init__(
self.feed_forward_proj = feed_forward_proj
self.use_cache = use_cache
+ act_info = self.feed_forward_proj.split("-")
+ self.dense_act_fn = act_info[-1]
+ self.is_gated_act = act_info[0] == "gated"
+
+        if (len(act_info) > 1 and act_info[0] != "gated") or len(act_info) > 2:
+ raise ValueError(
+ f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
+ "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+ "'gated-gelu' or 'relu'"
+ )
+
+ # for backwards compatibility
+ if feed_forward_proj == "gated-gelu":
+ self.dense_act_fn = "gelu_new"
+
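For reference, the attributes derived by the parsing above can be checked directly; the printed values follow from the code shown here, including the `"gated-gelu"` backwards-compatibility branch:

```python
from transformers import MT5Config

config = MT5Config(feed_forward_proj="gated-gelu")
print(config.is_gated_act, config.dense_act_fn)  # True gelu_new (special-cased above)

config = MT5Config(feed_forward_proj="relu")
print(config.is_gated_act, config.dense_act_fn)  # False relu
```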
@property
def hidden_size(self):
return self.d_model
@@ -128,3 +145,29 @@ def num_attention_heads(self):
@property
def num_hidden_layers(self):
return self.num_layers
+
+
+# Copied from transformers.models.t5.configuration_t5.T5OnnxConfig
+class MT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ common_inputs = {
+ "input_ids": {0: "batch", 1: "encoder_sequence"},
+ "attention_mask": {0: "batch", 1: "encoder_sequence"},
+ }
+ if self.use_past:
+ common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
+ common_inputs["decoder_input_ids"] = {0: "batch"}
+ common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+ else:
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+ common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+ if self.use_past:
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
+
+ return common_inputs
+
+ @property
+ def default_onnx_opset(self) -> int:
+ return 13
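The `inputs` property above declares the dynamic axes the ONNX exporter will use; a quick way to inspect them is to build the config from a model config and print the mapping. A minimal sketch, assuming the inherited `from_model_config` helper with its default settings (`use_past=False`, so the `else` branch above applies):

```python
from transformers import MT5Config
from transformers.models.mt5.configuration_mt5 import MT5OnnxConfig

onnx_config = MT5OnnxConfig.from_model_config(MT5Config())

print(onnx_config.default_onnx_opset)  # 13
for name, axes in onnx_config.inputs.items():
    print(name, axes)
# input_ids {0: 'batch', 1: 'encoder_sequence'}
# attention_mask {0: 'batch', 1: 'encoder_sequence'}
# decoder_input_ids {0: 'batch', 1: 'decoder_sequence'}
# decoder_attention_mask {0: 'batch', 1: 'decoder_sequence'}
```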
diff --git a/src/transformers/models/mt5/modeling_flax_mt5.py b/src/transformers/models/mt5/modeling_flax_mt5.py
index d45ea49645d3..4f2fa5b9fb39 100644
--- a/src/transformers/models/mt5/modeling_flax_mt5.py
+++ b/src/transformers/models/mt5/modeling_flax_mt5.py
@@ -14,8 +14,10 @@
# limitations under the License.
""" Flax mT5 model."""
+import numpy as np
+
from ...utils import logging
-from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
+from ..t5.modeling_flax_t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model
from .configuration_mt5 import MT5Config
@@ -25,6 +27,19 @@
_TOKENIZER_FOR_DOC = "T5Tokenizer"
+# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
+def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = np.zeros_like(input_ids)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+ return shifted_input_ids
+
+
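The `shift_tokens_right` helper copied in above is what turns `labels` into `decoder_input_ids`: every id moves one position to the right, the decoder start token is prepended, and `-100` label padding is mapped back to the real pad id. A small worked example using the helper defined just above (the token ids are made up; for the T5 family both the pad and decoder start token id are 0):

```python
import numpy as np

labels = np.array([[42, 17, 99, 1, -100, -100]])  # 1 = eos, -100 = ignored label padding
decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=0)
print(decoder_input_ids)
# [[ 0 42 17 99  1  0]]
```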
class FlaxMT5Model(FlaxT5Model):
r"""
This class overrides [`FlaxT5Model`]. Please check the superclass for the appropriate documentation alongside usage
@@ -42,8 +57,7 @@ class FlaxMT5Model(FlaxT5Model):
>>> summary = "Weiter Verhandlung in Syrien."
>>> inputs = tokenizer(article, return_tensors="np")
- >>> with tokenizer.as_target_tokenizer():
- ... decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids
+ >>> decoder_input_ids = tokenizer(text_target=summary, return_tensors="np").input_ids
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=decoder_input_ids)
>>> hidden_states = outputs.last_hidden_state
@@ -52,6 +66,32 @@ class FlaxMT5Model(FlaxT5Model):
config_class = MT5Config
+class FlaxMT5EncoderModel(FlaxT5EncoderModel):
+ r"""
+ This class overrides [`FlaxT5EncoderModel`]. Please check the superclass for the appropriate documentation
+ alongside usage examples.
+
+ Examples:
+
+ ```python
+    >>> from transformers import FlaxMT5EncoderModel, T5Tokenizer
+
+    >>> model = FlaxMT5EncoderModel.from_pretrained("google/mt5-small")
+ >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
+
+ >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
+ >>> summary = "Weiter Verhandlung in Syrien."
+ >>> inputs = tokenizer(article, return_tensors="np")
+
+ >>> decoder_input_ids = tokenizer(text_target=summary, return_tensors="np").input_ids
+
+ >>> outputs = model(input_ids=inputs["input_ids"])
+ >>> hidden_states = outputs.last_hidden_state
+ ```"""
+ model_type = "mt5"
+ config_class = MT5Config
+
+
class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration):
r"""
This class overrides [`FlaxT5ForConditionalGeneration`]. Please check the superclass for the appropriate
@@ -69,8 +109,7 @@ class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration):
>>> summary = "Weiter Verhandlung in Syrien."
>>> inputs = tokenizer(article, return_tensors="np")
- >>> with tokenizer.as_target_tokenizer():
- ... decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids
+ >>> decoder_input_ids = tokenizer(text_target=summary, return_tensors="np").input_ids
>>> outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
>>> logits = outputs.logits
diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py
index 314198c69a9a..c562b011522d 100644
--- a/src/transformers/models/mt5/modeling_mt5.py
+++ b/src/transformers/models/mt5/modeling_mt5.py
@@ -40,8 +40,7 @@ class MT5Model(T5Model):
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> inputs = tokenizer(article, return_tensors="pt")
- >>> with tokenizer.as_target_tokenizer():
- ... labels = tokenizer(summary, return_tensors="pt")
+ >>> labels = tokenizer(text_target=summary, return_tensors="pt")
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
>>> hidden_states = outputs.last_hidden_state
@@ -49,13 +48,13 @@ class MT5Model(T5Model):
model_type = "mt5"
config_class = MT5Config
_keys_to_ignore_on_load_missing = [
- r"encoder\.embed_tokens\.weight",
- r"decoder\.embed_tokens\.weight",
- r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
+ r"encoder.embed_tokens.weight",
+ r"decoder.embed_tokens.weight",
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
_keys_to_ignore_on_save = [
- r"encoder\.embed_tokens\.weight",
- r"decoder\.embed_tokens\.weight",
+ r"encoder.embed_tokens.weight",
+ r"decoder.embed_tokens.weight",
]
@@ -73,21 +72,19 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
- >>> inputs = tokenizer(article, return_tensors="pt")
- >>> with tokenizer.as_target_tokenizer():
- ... labels = tokenizer(summary, return_tensors="pt")
+ >>> inputs = tokenizer(article, text_target=summary, return_tensors="pt")
- >>> outputs = model(**inputs, labels=labels["input_ids"])
+ >>> outputs = model(**inputs)
>>> loss = outputs.loss
```"""
model_type = "mt5"
config_class = MT5Config
_keys_to_ignore_on_load_missing = [
- r"encoder\.embed_tokens\.weight",
+ r"encoder.embed_tokens.weight",
]
_keys_to_ignore_on_save = [
- r"encoder\.embed_tokens\.weight",
+ r"encoder.embed_tokens.weight",
]
@@ -112,8 +109,8 @@ class MT5EncoderModel(T5EncoderModel):
model_type = "mt5"
config_class = MT5Config
_keys_to_ignore_on_load_missing = [
- r"encoder\.embed_tokens\.weight",
+ r"encoder.embed_tokens.weight",
]
_keys_to_ignore_on_save = [
- r"encoder\.embed_tokens\.weight",
+ r"encoder.embed_tokens.weight",
]
diff --git a/src/transformers/models/mt5/modeling_tf_mt5.py b/src/transformers/models/mt5/modeling_tf_mt5.py
index 2808b8421a16..71aa0bb66a7a 100644
--- a/src/transformers/models/mt5/modeling_tf_mt5.py
+++ b/src/transformers/models/mt5/modeling_tf_mt5.py
@@ -40,8 +40,7 @@ class TFMT5Model(TFT5Model):
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> inputs = tokenizer(article, return_tensors="tf")
- >>> with tokenizer.as_target_tokenizer():
- ... labels = tokenizer(summary, return_tensors="tf")
+ >>> labels = tokenizer(text_target=summary, return_tensors="tf")
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
>>> hidden_states = outputs.last_hidden_state
@@ -64,11 +63,9 @@ class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration):
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
- >>> inputs = tokenizer(article, return_tensors="tf")
- >>> with tokenizer.as_target_tokenizer():
- ... labels = tokenizer(summary, return_tensors="tf")
+ >>> inputs = tokenizer(article, text_target=summary, return_tensors="tf")
- >>> outputs = model(**inputs, labels=labels["input_ids"])
+ >>> outputs = model(**inputs)
>>> loss = outputs.loss
```"""
diff --git a/src/transformers/models/mvp/__init__.py b/src/transformers/models/mvp/__init__.py
new file mode 100644
index 000000000000..865b958d3911
--- /dev/null
+++ b/src/transformers/models/mvp/__init__.py
@@ -0,0 +1,83 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
+
+
+_import_structure = {
+ "configuration_mvp": ["MVP_PRETRAINED_CONFIG_ARCHIVE_MAP", "MvpConfig", "MvpOnnxConfig"],
+ "tokenization_mvp": ["MvpTokenizer"],
+}
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_mvp_fast"] = ["MvpTokenizerFast"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_mvp"] = [
+ "MVP_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "MvpForCausalLM",
+ "MvpForConditionalGeneration",
+ "MvpForQuestionAnswering",
+ "MvpForSequenceClassification",
+ "MvpModel",
+ "MvpPreTrainedModel",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_mvp import MVP_PRETRAINED_CONFIG_ARCHIVE_MAP, MvpConfig, MvpOnnxConfig
+ from .tokenization_mvp import MvpTokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_mvp_fast import MvpTokenizerFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_mvp import (
+ MVP_PRETRAINED_MODEL_ARCHIVE_LIST,
+ MvpForCausalLM,
+ MvpForConditionalGeneration,
+ MvpForQuestionAnswering,
+ MvpForSequenceClassification,
+ MvpModel,
+ MvpPreTrainedModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/mvp/configuration_mvp.py b/src/transformers/models/mvp/configuration_mvp.py
new file mode 100644
index 000000000000..63a006b8e429
--- /dev/null
+++ b/src/transformers/models/mvp/configuration_mvp.py
@@ -0,0 +1,182 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" MVP model configuration"""
+import warnings
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+MVP_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "RUCAIBox/mvp": "https://huggingface.co/RUCAIBox/mvp/resolve/main/config.json",
+}
+
+
+class MvpConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`MvpModel`]. It is used to instantiate an MVP model
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the MVP [RUCAIBox/mvp](https://huggingface.co/RUCAIBox/mvp)
+ architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50267):
+ Vocabulary size of the MVP model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`MvpModel`].
+ d_model (`int`, *optional*, defaults to 1024):
+ Dimensionality of the layers and the pooler layer.
+ encoder_layers (`int`, *optional*, defaults to 12):
+ Number of encoder layers.
+ decoder_layers (`int`, *optional*, defaults to 12):
+ Number of decoder layers.
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the classifier.
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+ for more details.
+ scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ forced_eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+ `eos_token_id`.
+        use_prompt (`bool`, *optional*, defaults to `False`):
+            Whether or not to use a prompt.
+        prompt_length (`int`, *optional*, defaults to 100):
+            The length of the prompt.
+        prompt_mid_dim (`int`, *optional*, defaults to 800):
+            Dimensionality of the "intermediate" layer of the prompt.
+
+    Example:
+
+ ```python
+ >>> from transformers import MvpModel, MvpConfig
+
+ >>> # Initializing a MVP RUCAIBox/mvp style configuration
+ >>> configuration = MvpConfig()
+
+ >>> # Initializing a model from the RUCAIBox/mvp style configuration
+ >>> model = MvpModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "mvp"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+ def __init__(
+ self,
+ vocab_size=50267,
+ max_position_embeddings=1024,
+ encoder_layers=12,
+ encoder_ffn_dim=4096,
+ encoder_attention_heads=16,
+ decoder_layers=12,
+ decoder_ffn_dim=4096,
+ decoder_attention_heads=16,
+ encoder_layerdrop=0.0,
+ decoder_layerdrop=0.0,
+ activation_function="gelu",
+ d_model=1024,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ init_std=0.02,
+ classifier_dropout=0.0,
+ scale_embedding=False,
+ use_cache=True,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ is_encoder_decoder=True,
+ decoder_start_token_id=2,
+ forced_eos_token_id=2,
+ use_prompt=False,
+ prompt_length=100,
+ prompt_mid_dim=800,
+ **kwargs
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.decoder_layers = decoder_layers
+ self.decoder_attention_heads = decoder_attention_heads
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.encoder_layerdrop = encoder_layerdrop
+ self.decoder_layerdrop = decoder_layerdrop
+ self.classifier_dropout = classifier_dropout
+ self.use_cache = use_cache
+ self.num_hidden_layers = encoder_layers
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+ self.use_prompt = use_prompt
+ self.prompt_length = prompt_length
+ self.prompt_mid_dim = prompt_mid_dim
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ decoder_start_token_id=decoder_start_token_id,
+ forced_eos_token_id=forced_eos_token_id,
+ **kwargs,
+ )
+
+ if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
+ self.forced_bos_token_id = self.bos_token_id
+ warnings.warn(
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
+ "The config can simply be saved and uploaded again to be fixed."
+ )
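The prompt-related options documented above (`use_prompt`, `prompt_length`, `prompt_mid_dim`) are what activate the `MvpPrompt` modules in the modeling file that follows. A minimal sketch of a prompt-tuning style setup; the name-based freezing heuristic is an illustration-only assumption, not an API of the library:

```python
from transformers import MvpConfig, MvpForConditionalGeneration

config = MvpConfig(use_prompt=True, prompt_length=100, prompt_mid_dim=800)
model = MvpForConditionalGeneration(config)

# illustrative: freeze everything except the prompt modules so only the prompts are trained
# (assumes the prompt parameters contain "prompt" in their names, as in MvpPrompt below)
for name, param in model.named_parameters():
    if "prompt" not in name:
        param.requires_grad = False
```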
diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py
new file mode 100644
index 000000000000..d3d239c4cff1
--- /dev/null
+++ b/src/transformers/models/mvp/modeling_mvp.py
@@ -0,0 +1,2050 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch MVP model."""
+import copy
+import math
+import random
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+ Seq2SeqQuestionAnsweringModelOutput,
+ Seq2SeqSequenceClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_code_sample_docstrings,
+ add_end_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_mvp import MvpConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "RUCAIBox/mvp"
+_CONFIG_FOR_DOC = "MvpConfig"
+_TOKENIZER_FOR_DOC = "MvpTokenizer"
+
+# Base model docstring
+_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
+
+MVP_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "RUCAIBox/mvp",
+ "RUCAIBox/mvp-data-to-text",
+ "RUCAIBox/mvp-open-dialog",
+ "RUCAIBox/mvp-question-answering",
+ "RUCAIBox/mvp-question-generation",
+ "RUCAIBox/mvp-story",
+ "RUCAIBox/mvp-summarization",
+ "RUCAIBox/mvp-task-dialog",
+ "RUCAIBox/mtl-data-to-text",
+ "RUCAIBox/mtl-multi-task",
+ "RUCAIBox/mtl-open-dialog",
+ "RUCAIBox/mtl-question-answering",
+ "RUCAIBox/mtl-question-generation",
+ "RUCAIBox/mtl-story",
+ "RUCAIBox/mtl-summarization",
+ # See all MVP models at https://huggingface.co/models?filter=mvp
+]
+
+
+# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ if pad_token_id is None:
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ return shifted_input_ids
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
+ mask_cond = torch.arange(mask.size(-1))
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
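`_make_causal_mask` fills future positions with the most negative representable value so that, once the additive mask is applied to the attention scores, softmax gives them effectively zero probability. A quick check of the pattern it produces, using the function defined just above:

```python
import torch

mask = _make_causal_mask(torch.Size([1, 4]), torch.float32)
print(mask.shape)  # torch.Size([1, 1, 4, 4])

# 1 where attention is allowed (current and past positions), 0 where the mask is the large negative value
print((mask[0, 0] == 0).long())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])
```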
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MVP
+class MvpLearnedPositionalEmbedding(nn.Embedding):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int):
+ # MVP is set up so that if padding_idx is specified then offset the embedding ids by 2
+ # and adjust num_embeddings appropriately. Other models don't have this hack
+ self.offset = 2
+ super().__init__(num_embeddings + self.offset, embedding_dim)
+
+ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+ bsz, seq_len = input_ids_shape[:2]
+ positions = torch.arange(
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
+ )
+ return super().forward(positions + self.offset)
+
+
+class MvpAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ attn_prompt: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ if attn_prompt is not None:
+ key_states = torch.cat([attn_prompt[0].expand(bsz, -1, -1, -1), key_states], dim=2)
+ value_states = torch.cat([attn_prompt[1].expand(bsz, -1, -1, -1), value_states], dim=2)
+ if attention_mask is not None:
+ prompt_mask = torch.zeros(bsz, 1, tgt_len, attn_prompt[0].size(1)).to(attention_mask.device)
+ attention_mask = torch.cat([prompt_mask, attention_mask], dim=(-1))
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
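The `attn_prompt` branch above is the MVP-specific part of the attention layer: per-layer prompt keys and values are prepended to the projected key/value states, so the attention weights gain `prompt_length` extra source positions. A small shape check, using `MvpAttention` as defined above and the prompt layout produced by `MvpPrompt` further below:

```python
import torch

attn = MvpAttention(embed_dim=64, num_heads=4)
hidden_states = torch.randn(2, 5, 64)    # (batch, seq_len, embed_dim)
attn_prompt = torch.randn(2, 4, 10, 16)  # (key/value, num_heads, prompt_length, head_dim)

attn_output, attn_weights, _ = attn(hidden_states, attn_prompt=attn_prompt, output_attentions=True)
print(attn_output.shape)   # torch.Size([2, 5, 64])
print(attn_weights.shape)  # torch.Size([2, 4, 5, 15]) -> 10 prompt positions + 5 token positions
```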
+
+class MvpEncoderLayer(nn.Module):
+ def __init__(self, config: MvpConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.self_attn = MvpAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ dropout=config.attention_dropout,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: torch.FloatTensor,
+ layer_head_mask: torch.FloatTensor,
+ self_attn_prompt: torch.FloatTensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ self_attn_prompt (`torch.FloatTensor`): prompt of self attention of shape
+ `(2, encoder_attention_heads, pro_len, head_dim)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states, attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ attn_prompt=self_attn_prompt,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ residual = hidden_states
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ if hidden_states.dtype == torch.float16 and (
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+ ):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class MvpDecoderLayer(nn.Module):
+ def __init__(self, config: MvpConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = MvpAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.encoder_attn = MvpAttention(
+ self.embed_dim,
+ config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+ self_attn_prompt: Optional[torch.Tensor] = None,
+ cross_attn_prompt: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+ size `(decoder_attention_heads,)`.
+ self_attn_prompt (`torch.FloatTensor`): prompt of self attention of shape
+ `(2, decoder_attention_heads, pro_len, head_dim)`.
+ cross_attn_prompt (`torch.FloatTensor`): prompt of cross attention of shape
+ `(2, decoder_attention_heads, pro_len, head_dim)`.
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ # Self Attention
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=self_attn_past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ attn_prompt=self_attn_prompt,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Cross-Attention Block
+ cross_attn_present_key_value = None
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=cross_attn_layer_head_mask,
+ attn_prompt=cross_attn_prompt,
+ past_key_value=cross_attn_past_key_value,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # add cross-attn to positions 3,4 of present_key_value tuple
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->MVP
+class MvpClassificationHead(nn.Module):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(
+ self,
+ input_dim: int,
+ inner_dim: int,
+ num_classes: int,
+ pooler_dropout: float,
+ ):
+ super().__init__()
+ self.dense = nn.Linear(input_dim, inner_dim)
+ self.dropout = nn.Dropout(p=pooler_dropout)
+ self.out_proj = nn.Linear(inner_dim, num_classes)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.dense(hidden_states)
+ hidden_states = torch.tanh(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.out_proj(hidden_states)
+ return hidden_states
+
+
+class MvpPrompt(nn.Module):
+ """Layer-wise prompt for encoder or decoder."""
+
+ def __init__(self, config, num_layers, num_heads):
+ super().__init__()
+ self.prompt_length = config.prompt_length
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+ self.head_dim = config.d_model // num_heads
+ self.dropout = nn.Dropout(p=config.dropout)
+ self.prompt_embedding = nn.Embedding(config.prompt_length, config.d_model)
+ self.prompt_trans = nn.Sequential(
+ nn.Linear(config.d_model, config.prompt_mid_dim),
+ nn.GELU(),
+ nn.Linear(config.prompt_mid_dim, num_layers * 2 * config.d_model),
+ )
+
+ def forward(self, prompt_ids: torch.Tensor) -> Tuple[torch.Tensor]:
+ prompt = self.prompt_trans(self.prompt_embedding(prompt_ids))
+ prompt = prompt.view(self.prompt_length, self.num_layers * 2, self.num_heads, self.head_dim)
+ prompt = self.dropout(prompt)
+ prompt = prompt.permute([1, 2, 0, 3]).split(2)
+ return prompt
+
+
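`MvpPrompt` maps a fixed set of prompt ids through a small bottleneck MLP and reshapes the result into one `(key, value)` pair of prompt tensors per layer, matching the `attn_prompt` format consumed by `MvpAttention` above. A small shape check (the sizes are arbitrary toy values):

```python
import torch

from transformers import MvpConfig
from transformers.models.mvp.modeling_mvp import MvpPrompt  # defined in this file

config = MvpConfig(d_model=64, prompt_length=10, prompt_mid_dim=32, dropout=0.0)
prompts = MvpPrompt(config, num_layers=2, num_heads=4)

per_layer = prompts(torch.arange(config.prompt_length))
print(len(per_layer))      # 2 -> one entry per layer
print(per_layer[0].shape)  # torch.Size([2, 4, 10, 16]) -> (key/value, num_heads, prompt_length, head_dim)
```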
+class MvpPreTrainedModel(PreTrainedModel):
+ config_class = MvpConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (MvpDecoder, MvpEncoder, MvpPrompt)):
+ module.gradient_checkpointing = value
+
+ @property
+ def dummy_inputs(self):
+ pad_token = self.config.pad_token_id
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+ dummy_inputs = {
+ "attention_mask": input_ids.ne(pad_token),
+ "input_ids": input_ids,
+ }
+ return dummy_inputs
+
+
+MVP_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`MvpConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+MVP_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`MvpTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`MvpTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
+
+ For translation and summarization training, `decoder_input_ids` should be provided. If no
+ `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
+ for denoising pre-training following the paper.
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+
+ If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden-states at
+            the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+ input (see `past_key_values`). This is useful if you want more control over how to convert
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+ of `inputs_embeds`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+MVP_CONDITIONAL_GENERATION_EXAMPLE = r"""
+ Example of summarization:
+
+ Fine-tuning a model
+ ```python
+ >>> import torch
+ >>> from transformers import MvpTokenizer, MvpForConditionalGeneration
+
+ >>> tokenizer = MvpTokenizer.from_pretrained("RUCAIBox/mvp")
+ >>> model = MvpForConditionalGeneration.from_pretrained("RUCAIBox/mvp")
+
+ >>> inputs = tokenizer(
+ ... "Summarize: You may want to stick it to your boss and leave your job, but don't do it if these are your reasons.",
+ ... return_tensors="pt",
+ ... )
+ >>> labels = tokenizer("Bad Reasons To Quit Your Job", return_tensors="pt")["input_ids"]
+
+ >>> loss = model(**inputs, labels=labels).loss
+ >>> loss.backward()
+ ```
+
+    Inference after fine-tuning the model
+ ```python
+ >>> with torch.no_grad():
+ ... generated_ids = model.generate(**inputs)
+
+ >>> generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ ```
+"""
+
+MVP_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
+ Example of single-label classification:
+
+ Fine-tuning a model on `num_labels` classes
+ ```python
+ >>> import torch
+ >>> from transformers import MvpTokenizer, MvpForSequenceClassification
+
+ >>> num_labels = 2 # for example, this is a binary classification task
+ >>> tokenizer = MvpTokenizer.from_pretrained("RUCAIBox/mvp")
+ >>> model = MvpForSequenceClassification.from_pretrained("RUCAIBox/mvp", num_labels=num_labels)
+
+ >>> inputs = tokenizer("Classify: Hello, my dog is cute", return_tensors="pt")
+    >>> labels = torch.tensor(1)  # the ground-truth label for this input
+
+ >>> loss = model(**inputs, labels=labels).loss
+ >>> loss.backward()
+ ```
+
+    Inference after fine-tuning the model
+ ```python
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits
+
+ >>> predicted_class_id = logits.argmax()
+ ```
+"""
+
+MVP_QUESTION_ANSWERING_SAMPLE = r"""
+ Example:
+
+    Fine-tuning a model for extractive question answering; generative question answering is also supported via
+    `MvpForConditionalGeneration`
+ ```python
+ >>> import torch
+ >>> from transformers import MvpTokenizer, MvpForQuestionAnswering
+
+ >>> tokenizer = MvpTokenizer.from_pretrained("RUCAIBox/mvp")
+ >>> model = MvpForQuestionAnswering.from_pretrained("RUCAIBox/mvp")
+
+ >>> inputs = tokenizer(
+ ... "Answer the following question: Who was Jim Henson? [SEP] Jim Henson was a nice puppet",
+ ... return_tensors="pt",
+ ... )
+ >>> target_start_index = torch.tensor([18])
+ >>> target_end_index = torch.tensor([19])
+
+ >>> loss = model(**inputs, start_positions=target_start_index, end_positions=target_end_index).loss
+ >>> loss.backward()
+ ```
+
+    Inference after fine-tuning the model
+ ```python
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> answer_start_index = outputs.start_logits.argmax()
+ >>> answer_end_index = outputs.end_logits.argmax()
+
+ >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
+ >>> predict_answer = tokenizer.decode(predict_answer_tokens)
+ ```
+"""
+
+
+class MvpEncoder(MvpPreTrainedModel):
+ """
+    Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
+ [`MvpEncoderLayer`].
+
+ Args:
+ config: MvpConfig
+        embed_tokens (nn.Embedding): input token embedding, optionally shared with the decoder
+        use_prompt (bool): whether to use layer-wise prompts (for lightweight tuning)
+ """
+
+ def __init__(
+ self, config: MvpConfig, embed_tokens: Optional[nn.Embedding] = None, use_prompt: Optional[bool] = False
+ ):
+ super().__init__(config)
+
+ self.dropout = config.dropout
+ self.layerdrop = config.encoder_layerdrop
+
+ embed_dim = config.d_model
+ self.padding_idx = config.pad_token_id
+ self.max_source_positions = config.max_position_embeddings
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+ if embed_tokens is not None:
+ self.embed_tokens = embed_tokens
+ else:
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
+ self.embed_positions = MvpLearnedPositionalEmbedding(
+ config.max_position_embeddings,
+ embed_dim,
+ )
+ self.layers = nn.ModuleList([MvpEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self.layernorm_embedding = nn.LayerNorm(embed_dim)
+
+ self.use_prompt = use_prompt
+ if use_prompt:
+ self.prompt_length = config.prompt_length
+ self.self_attn_prompt = MvpPrompt(
+ config,
+ config.encoder_layers,
+ config.encoder_attention_heads,
+ )
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`MvpTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
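+        # add learned position embeddings to the token embeddings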
+ embed_pos = self.embed_positions(input_shape)
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = self.layernorm_embedding(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # layer-wise prompt
+ if self.use_prompt:
+ prompt_ids = torch.arange(self.prompt_length).to(self.device)
+ self_attn_prompt = self.self_attn_prompt(prompt_ids)
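+            # each encoder layer consumes its own slice of the prompt below (self_attn_prompt[idx])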
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = random.uniform(0, 1)
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
+ layer_outputs = (None, None)
+ else:
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ attention_mask,
+ (head_mask[idx] if head_mask is not None else None),
+ (self_attn_prompt[idx] if self.use_prompt else None),
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None),
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class MvpDecoder(MvpPreTrainedModel):
+ """
+    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MvpDecoderLayer`].
+
+ Args:
+ config: MvpConfig
+        embed_tokens (nn.Embedding): input token embedding, optionally shared with the encoder
+        use_prompt (bool): whether to use layer-wise prompts (for lightweight tuning)
+ """
+
+ def __init__(
+ self, config: MvpConfig, embed_tokens: Optional[nn.Embedding] = None, use_prompt: Optional[bool] = False
+ ):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layerdrop = config.decoder_layerdrop
+ self.padding_idx = config.pad_token_id
+ self.max_target_positions = config.max_position_embeddings
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+
+ if embed_tokens is not None:
+ self.embed_tokens = embed_tokens
+ else:
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
+ self.embed_positions = MvpLearnedPositionalEmbedding(
+ config.max_position_embeddings,
+ config.d_model,
+ )
+ self.layers = nn.ModuleList([MvpDecoderLayer(config) for _ in range(config.decoder_layers)])
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
+
+ self.use_prompt = use_prompt
+ if use_prompt:
+ self.prompt_length = config.prompt_length
+ self.self_attn_prompt = MvpPrompt(
+ config,
+ config.decoder_layers,
+ config.decoder_attention_heads,
+ )
+ self.cross_attn_prompt = MvpPrompt(
+ config,
+ config.decoder_layers,
+ config.decoder_attention_heads,
+ )
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
+ ).to(inputs_embeds.device)
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+ combined_attention_mask = (
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`MvpTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+
+ # embed positions
+ positions = self.embed_positions(input_shape, past_key_values_length)
+
+ hidden_states = inputs_embeds + positions
+ hidden_states = self.layernorm_embedding(hidden_states)
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # layer-wise prompt
+ if self.use_prompt:
+ prompt_ids = torch.arange(self.prompt_length).to(self.device)
+ self_attn_prompt = self.self_attn_prompt(prompt_ids)
+ cross_attn_prompt = self.cross_attn_prompt(prompt_ids)
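+            # each decoder layer consumes its own slice of these prompts below (indexed by layer idx)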
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+ )
+
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ dropout_probability = random.uniform(0, 1)
+ if self.training and (dropout_probability < self.layerdrop):
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, use_cache)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ self_attn_prompt[idx] if self.use_prompt else None,
+ cross_attn_prompt[idx] if self.use_prompt else None,
+ None,
+ )
+ else:
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+ ),
+ self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None),
+ cross_attn_prompt=(cross_attn_prompt[idx] if self.use_prompt else None),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ "The bare MVP Model outputting raw hidden-states without any specific head on top.",
+ MVP_START_DOCSTRING,
+)
+class MvpModel(MvpPreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
+
+ def __init__(self, config: MvpConfig):
+ super().__init__(config)
+
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
+ self.use_prompt = config.use_prompt
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+
+ self.encoder = MvpEncoder(config, self.shared, config.use_prompt)
+ self.decoder = MvpDecoder(config, self.shared, config.use_prompt)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, value):
+ self.shared = value
+ self.encoder.embed_tokens = self.shared
+ self.decoder.embed_tokens = self.shared
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_decoder(self):
+ return self.decoder
+
+ def set_lightweight_tuning(self):
+ assert self.use_prompt, "If you want to use lightweight tuning, make sure that `use_prompt=True`."
+
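+        # freeze the whole model, then re-enable gradients only for the prompt modules below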
+ self.requires_grad_(False)
+ self.encoder.self_attn_prompt.requires_grad_(True)
+ self.decoder.self_attn_prompt.requires_grad_(True)
+ self.decoder.cross_attn_prompt.requires_grad_(True)
+
+ @add_start_docstrings_to_model_forward(MVP_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=Seq2SeqModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Seq2SeqModelOutput]:
+
+        # unlike other models, Mvp automatically creates decoder_input_ids from
+        # input_ids if no decoder_input_ids are provided
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ if input_ids is None:
+ raise ValueError(
+ "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
+ "passed, `input_ids` cannot be `None`. Please pass either "
+ "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
+ )
+
+ decoder_input_ids = shift_tokens_right(
+ input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ "The MVP Model with a language modeling head. Can be used for various text generation tasks.", MVP_START_DOCSTRING
+)
+class MvpForConditionalGeneration(MvpPreTrainedModel):
+ def __init__(self, config: MvpConfig):
+ super().__init__(config)
+ self.model = MvpModel(config)
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.model.get_encoder()
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens)
+ self._resize_final_logits_bias(new_num_tokens)
+ return new_embeddings
+
+ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
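+        # keep the logits bias buffer in sync with the resized vocabulary: truncate it or pad it with zeros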
+ old_num_tokens = self.final_logits_bias.shape[-1]
+ if new_num_tokens <= old_num_tokens:
+ new_bias = self.final_logits_bias[:, :new_num_tokens]
+ else:
+ extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
+ new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
+ self.register_buffer("final_logits_bias", new_bias)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_lightweight_tuning(self):
+ self.model.set_lightweight_tuning()
+ self.lm_head.requires_grad_(False)
+
+ @add_start_docstrings_to_model_forward(MVP_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ @add_end_docstrings(MVP_CONDITIONAL_GENERATION_EXAMPLE)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None:
+ if use_cache:
+ logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+ use_cache = False
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
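+                # derive decoder inputs from the labels by shifting them one position to the right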
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return Seq2SeqLMOutput(
+ loss=masked_lm_loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ past=None,
+ attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ **kwargs
+ ):
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ decoder_input_ids = decoder_input_ids[:, -1:]
+
+ return {
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
+ "encoder_outputs": encoder_outputs,
+ "past_key_values": past,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
+ }
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+ @staticmethod
+ def _reorder_cache(past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ # cached cross_attention states don't have to be reordered -> they are always the same
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+ )
+ return reordered_past
+
+
+@add_start_docstrings(
+ """
+    Mvp model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for GLUE
+    tasks.
+ """,
+ MVP_START_DOCSTRING,
+)
+class MvpForSequenceClassification(MvpPreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
+
+ def __init__(self, config: MvpConfig, **kwargs):
+ super().__init__(config, **kwargs)
+ self.model = MvpModel(config)
+ self.classification_head = MvpClassificationHead(
+ config.d_model,
+ config.d_model,
+ config.num_labels,
+ config.classifier_dropout,
+ )
+
+ self.model._init_weights(self.classification_head.dense)
+ self.model._init_weights(self.classification_head.out_proj)
+
+ def set_lightweight_tuning(self):
+ self.model.set_lightweight_tuning()
+ self.classification_head.requires_grad_(False)
+
+ @add_start_docstrings_to_model_forward(MVP_INPUTS_DOCSTRING)
+ @add_end_docstrings(MVP_SEQUENCE_CLASSIFICATION_SAMPLE)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ if input_ids is None and inputs_embeds is not None:
+ raise NotImplementedError(
+ f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
+ )
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ encoder_outputs=encoder_outputs,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0] # last hidden state
+
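+        # take the decoder hidden state at the last <eos> token of each sequence as the sentence representation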
+ eos_mask = input_ids.eq(self.config.eos_token_id)
+
+ if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
+ raise ValueError("All examples must have the same number of tokens.")
+ sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
+ :, -1, :
+ ]
+ logits = self.classification_head(sentence_representation)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.config.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.config.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return Seq2SeqSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ MVP Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer
+ on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ MVP_START_DOCSTRING,
+)
+class MvpForQuestionAnswering(MvpPreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ config.num_labels = 2
+ self.num_labels = config.num_labels
+
+ self.model = MvpModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.model._init_weights(self.qa_outputs)
+
+ def set_lightweight_tuning(self):
+ self.model.set_lightweight_tuning()
+ self.qa_outputs.requires_grad_(False)
+
+ @add_start_docstrings_to_model_forward(MVP_INPUTS_DOCSTRING)
+ @add_end_docstrings(MVP_QUESTION_ANSWERING_SAMPLE)
+ def forward(
+ self,
+ input_ids: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if start_positions is not None and end_positions is not None:
+ use_cache = False
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ encoder_outputs=encoder_outputs,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
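+        # project to two scores per token, then split them into start and end logits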
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, splitting adds an extra dimension; remove it
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs; we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (
+ start_logits,
+ end_logits,
+ ) + outputs[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return Seq2SeqQuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Mvp
+class MvpDecoderWrapper(MvpPreTrainedModel):
+ """
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
+ used in combination with the [`EncoderDecoderModel`] framework.
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.decoder = MvpDecoder(config)
+
+ def forward(self, *args, **kwargs):
+ return self.decoder(*args, **kwargs)
+
+
+class MvpForCausalLM(MvpPreTrainedModel):
+ def __init__(self, config):
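+        # copy the config and force a decoder-only setup so the wrapped MvpDecoder acts as a standalone causal LM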
+ config = copy.deepcopy(config)
+ config.is_decoder = True
+ config.is_encoder_decoder = False
+ super().__init__(config)
+ self.model = MvpDecoderWrapper(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.decoder = decoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ def set_lightweight_tuning(self):
+ self.model.set_lightweight_tuning()
+ self.lm_head.requires_grad_(False)
+
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`MvpTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ if the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
+                in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import MvpTokenizer, MvpForCausalLM
+
+ >>> tokenizer = MvpTokenizer.from_pretrained("RUCAIBox/mvp")
+ >>> model = MvpForCausalLM.from_pretrained("RUCAIBox/mvp", add_cross_attention=False)
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> logits = outputs.logits
+ >>> list(logits.shape)
+ [1, 8, 50267]
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = self.lm_head(outputs[0])
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+
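+        # once past key values are cached, only the last generated token has to be passed to the model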
+ if past:
+ input_ids = input_ids[:, -1:]
+ # first step, decoder_cached_states are empty
+ return {
+ "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "use_cache": use_cache,
+ }
+
+ @staticmethod
+ def _reorder_cache(past, beam_idx):
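+        # reorder the cached key/value states so they follow the beams selected at each generation step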
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
diff --git a/src/transformers/models/mvp/tokenization_mvp.py b/src/transformers/models/mvp/tokenization_mvp.py
new file mode 100644
index 000000000000..3d5d606d63b5
--- /dev/null
+++ b/src/transformers/models/mvp/tokenization_mvp.py
@@ -0,0 +1,404 @@
+# coding=utf-8
+# Copyright 2022 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from functools import lru_cache
+from typing import List, Optional, Tuple
+
+import regex as re
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
+
+# See all MVP models at https://huggingface.co/models?filter=mvp
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "RUCAIBox/mvp": "https://huggingface.co/RUCAIBox/mvp/resolve/main/vocab.json",
+ },
+ "added_tokens.json": {
+ "RUCAIBox/mvp": "https://huggingface.co/RUCAIBox/mvp/resolve/main/added_tokens.json",
+ },
+ "merges_file": {
+ "RUCAIBox/mvp": "https://huggingface.co/RUCAIBox/mvp/resolve/main/merges.txt",
+ },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "RUCAIBox/mvp": 1024,
+}
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns a mapping of utf-8 bytes to unicode strings. We specifically avoid mapping to whitespace/control
+ characters that the bpe code barfs on.
+
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+ tables between utf-8 bytes and unicode strings.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("”"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
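As a quick sanity check of the reversibility claim in the docstring above, here is a small sketch (assuming `bytes_to_unicode` as defined in this file) showing that the byte-to-unicode mapping round-trips arbitrary UTF-8 text:

```python
# Sketch: the 256 byte values map to distinct printable unicode characters, so any
# UTF-8 string can be rewritten losslessly before byte-level BPE is applied.
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

text = "héllo 世界"
mapped = "".join(byte_encoder[b] for b in text.encode("utf-8"))
restored = bytearray(byte_decoder[c] for c in mapped).decode("utf-8")
assert len(byte_encoder) == 256 and restored == text
```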
+
+
+def get_pairs(word):
+ """
+ Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+class MvpTokenizer(PreTrainedTokenizer):
+ """
+ Constructs an MVP tokenizer, which is similar to the RoBERTa tokenizer, using byte-level Byte-Pair-Encoding.
+
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+ ```
+ >>> from transformers import MvpTokenizer
+ >>> tokenizer = MvpTokenizer.from_pretrained("RUCAIBox/mvp")
+ >>> tokenizer("Hello world")['input_ids']
+ [0, 31414, 232, 2]
+ >>> tokenizer(" Hello world")['input_ids']
+ [0, 20920, 232, 2]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+ call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+
+
+ When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
+
+
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+
+
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+
+
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows the leading word to be treated just like any
+ other word. (The MVP tokenizer detects the beginning of words by the preceding space.)
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ errors="replace",
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ add_prefix_space=False,
+ **kwargs
+ ):
+ bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+ eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+ sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
+ cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
+ unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+
+ # Mask token behave like a normal word, i.e. include the space before it
+ mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+ super().__init__(
+ errors=errors,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ add_prefix_space=add_prefix_space,
+ **kwargs,
+ )
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
+ self.encoder = json.load(vocab_handle)
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ bpe_merges = merges_handle.read().split("\n")[1:-1]
+ bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ self.cache = {}
+ self.add_prefix_space = add_prefix_space
+
+ # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+
+ @property
+ def vocab_size(self):
+ return len(self.encoder)
+
+ def get_vocab(self):
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j])
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ def _tokenize(self, text):
+ """Tokenize a string."""
+ bpe_tokens = []
+ for token in re.findall(self.pat, text):
+ token = "".join(
+ self.byte_encoder[b] for b in token.encode("utf-8")
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+ return bpe_tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.decoder.get(index)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ text = "".join(tokens)
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+ return text
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ merge_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+ )
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write("#version: 0.2\n")
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!"
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ return vocab_file, merge_file
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. An MVP sequence has the following format:
+
+ - single sequence: `<s> X </s>`
+ - pair of sequences: `<s> A </s></s> B </s>`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. MVP does not
+ make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of zeros.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+ add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
+ if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
+ text = " " + text
+ return (text, kwargs)
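A hedged usage sketch of the behavior documented in the `MvpTokenizer` docstring above: the sensitivity to a leading space, the `add_prefix_space` escape hatch, and the `<s> ... </s>` wrapping added by `build_inputs_with_special_tokens` (exact token ids depend on the hosted vocabulary):

```python
# Sketch of the space-sensitivity and special-token behavior documented above.
from transformers import MvpTokenizer

tokenizer = MvpTokenizer.from_pretrained("RUCAIBox/mvp")

# The same word gets different ids depending on the leading space (byte-level BPE).
print(tokenizer("Hello world")["input_ids"])
print(tokenizer(" Hello world")["input_ids"])

# add_prefix_space=True makes the leading word behave like any other word.
print(tokenizer("Hello world", add_prefix_space=True)["input_ids"])

# Single sequences are wrapped as <s> ... </s> by build_inputs_with_special_tokens.
print(tokenizer.decode(tokenizer("Hello world")["input_ids"]))
```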
diff --git a/src/transformers/models/mvp/tokenization_mvp_fast.py b/src/transformers/models/mvp/tokenization_mvp_fast.py
new file mode 100644
index 000000000000..00b7a5c6651e
--- /dev/null
+++ b/src/transformers/models/mvp/tokenization_mvp_fast.py
@@ -0,0 +1,287 @@
+# coding=utf-8
+# Copyright 2022 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import List, Optional, Tuple
+
+from tokenizers import pre_tokenizers, processors
+
+from ...tokenization_utils_base import AddedToken, BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_mvp import MvpTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+# See all MVP models at https://huggingface.co/models?filter=mvp
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "RUCAIBox/mvp": "https://huggingface.co/RUCAIBox/mvp/resolve/main/vocab.json",
+ },
+ "added_tokens.json": {
+ "RUCAIBox/mvp": "https://huggingface.co/RUCAIBox/mvp/resolve/main/added_tokens.json",
+ },
+ "merges_file": {
+ "RUCAIBox/mvp": "https://huggingface.co/RUCAIBox/mvp/resolve/main/merges.txt",
+ },
+ "tokenizer_file": {
+ "RUCAIBox/mvp": "https://huggingface.co/RUCAIBox/mvp/resolve/main/tokenizer.json",
+ },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "RUCAIBox/mvp": 1024,
+}
+
+
+class MvpTokenizerFast(PreTrainedTokenizerFast):
+ r"""
+ Construct a "fast" MVP tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2 tokenizer,
+ using byte-level Byte-Pair-Encoding.
+
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+ ```
+ >>> from transformers import MvpTokenizerFast
+ >>> tokenizer = MvpTokenizerFast.from_pretrained("RUCAIBox/mvp")
+ >>> tokenizer("Hello world")['input_ids']
+ [0, 31414, 232, 2]
+ >>> tokenizer(" Hello world")['input_ids']
+ [0, 20920, 232, 2]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+ call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+
+
+ When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
+
+
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+
+
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+
+
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows the leading word to be treated just like any
+ other word. (The MVP tokenizer detects the beginning of words by the preceding space.)
+ trim_offsets (`bool`, *optional*, defaults to `True`):
+ Whether the post processing step should trim offsets to avoid including whitespaces.
+ """
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = MvpTokenizer
+
+ def __init__(
+ self,
+ vocab_file=None,
+ merges_file=None,
+ tokenizer_file=None,
+ errors="replace",
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ add_prefix_space=False,
+ trim_offsets=True,
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file,
+ merges_file,
+ tokenizer_file=tokenizer_file,
+ errors=errors,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ add_prefix_space=add_prefix_space,
+ trim_offsets=trim_offsets,
+ **kwargs,
+ )
+
+ pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
+ if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+ pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
+ pre_tok_state["add_prefix_space"] = add_prefix_space
+ self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
+
+ self.add_prefix_space = add_prefix_space
+
+ # the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
+ tokenizer_component = "post_processor"
+ tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
+ if tokenizer_component_instance:
+ state = json.loads(tokenizer_component_instance.__getstate__())
+
+ # The lists 'sep' and 'cls' must be cast to tuples for the object `post_processor_class`
+ if "sep" in state:
+ state["sep"] = tuple(state["sep"])
+ if "cls" in state:
+ state["cls"] = tuple(state["cls"])
+
+ changes_to_apply = False
+
+ if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+ state["add_prefix_space"] = add_prefix_space
+ changes_to_apply = True
+
+ if state.get("trim_offsets", trim_offsets) != trim_offsets:
+ state["trim_offsets"] = trim_offsets
+ changes_to_apply = True
+
+ if changes_to_apply:
+ component_class = getattr(processors, state.pop("type"))
+ new_value = component_class(**state)
+ setattr(self.backend_tokenizer, tokenizer_component, new_value)
+
+ @property
+ def mask_token(self) -> str:
+ """
+ `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not
+ having been set.
+
+ The MVP tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily
+ comprise the space before the *<mask>*.
+ """
+ if self._mask_token is None:
+ if self.verbose:
+ logger.error("Using mask_token, but it is not set yet.")
+ return None
+ return str(self._mask_token)
+
+ @mask_token.setter
+ def mask_token(self, value):
+ """
+ Overriding the default behavior of the mask token to have it eat the space before it.
+
+ This is needed to preserve backward compatibility with all the previously used models based on Mvp.
+ """
+ # Mask token behave like a normal word, i.e. include the space before it
+ # So we set lstrip to True
+ value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
+ self._mask_token = value
+
+ def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+
+ if is_split_into_words and not self.add_prefix_space:
+ raise ValueError(
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+ "to use it with pretokenized inputs."
+ )
+
+ return super()._batch_encode_plus(*args, **kwargs)
+
+ def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+
+ if is_split_into_words and not self.add_prefix_space:
+ raise ValueError(
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+ "to use it with pretokenized inputs."
+ )
+
+ return super()._encode_plus(*args, **kwargs)
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+ if token_ids_1 is None:
+ return output
+
+ return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. MVP does not
+ make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of zeros.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
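A short sketch of the mask-token behavior described in the `mask_token` property above; because the token is registered with `lstrip=True`, the space preceding `<mask>` is absorbed into it, which keeps fill-mask style prompts natural to write (output ids depend on the hosted vocabulary):

```python
# Sketch of the greedy space handling around <mask> described above.
from transformers import MvpTokenizerFast

tokenizer = MvpTokenizerFast.from_pretrained("RUCAIBox/mvp")

ids = tokenizer("The capital of France is <mask>.")["input_ids"]
print(ids)
# lstrip=True on the mask token means the preceding space is folded into <mask>
print(tokenizer.convert_ids_to_tokens(ids))
```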
diff --git a/src/transformers/models/nezha/__init__.py b/src/transformers/models/nezha/__init__.py
new file mode 100644
index 000000000000..9811ee325250
--- /dev/null
+++ b/src/transformers/models/nezha/__init__.py
@@ -0,0 +1,74 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+# rely on isort to merge the imports
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
+
+
+_import_structure = {
+ "configuration_nezha": ["NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP", "NezhaConfig"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_nezha"] = [
+ "NEZHA_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "NezhaForNextSentencePrediction",
+ "NezhaForMaskedLM",
+ "NezhaForPreTraining",
+ "NezhaForMultipleChoice",
+ "NezhaForQuestionAnswering",
+ "NezhaForSequenceClassification",
+ "NezhaForTokenClassification",
+ "NezhaModel",
+ "NezhaPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_nezha import NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP, NezhaConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_nezha import (
+ NEZHA_PRETRAINED_MODEL_ARCHIVE_LIST,
+ NezhaForMaskedLM,
+ NezhaForMultipleChoice,
+ NezhaForNextSentencePrediction,
+ NezhaForPreTraining,
+ NezhaForQuestionAnswering,
+ NezhaForSequenceClassification,
+ NezhaForTokenClassification,
+ NezhaModel,
+ NezhaPreTrainedModel,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/nezha/configuration_nezha.py b/src/transformers/models/nezha/configuration_nezha.py
new file mode 100644
index 000000000000..eb57016cd45d
--- /dev/null
+++ b/src/transformers/models/nezha/configuration_nezha.py
@@ -0,0 +1,110 @@
+from transformers import PretrainedConfig
+
+
+NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "sijunhe/nezha-cn-base": "https://huggingface.co/sijunhe/nezha-cn-base/resolve/main/config.json",
+}
+
+
+class NezhaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`NezhaModel`]. It is used to instantiate a Nezha
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Nezha
+ [sijunhe/nezha-cn-base](https://huggingface.co/sijunhe/nezha-cn-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, optional, defaults to 21128):
+ Vocabulary size of the NEZHA model. Defines the number of different tokens that can be represented by the
+ *input_ids* passed to the forward method of [`NezhaModel`].
+ embedding_size (`int`, optional, defaults to 128):
+ Dimensionality of vocabulary embeddings.
+ hidden_size (`int`, optional, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, optional, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, optional, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, optional, defaults to 3072):
+ The dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, optional, defaults to "gelu"):
+ The non-linear activation function (function or string) in the encoder and pooler.
+ hidden_dropout_prob (`float`, optional, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, optional, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, optional, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, optional, defaults to 2):
+ The vocabulary size of the *token_type_ids* passed into [`NezhaModel`].
+ initializer_range (`float`, optional, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, optional, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ classifier_dropout (`float`, optional, defaults to 0.1):
+ The dropout ratio for attached classifiers.
+
+ Example:
+
+ ```python
+ >>> from transformers import NezhaConfig, NezhaModel
+
+ >>> # Initializing a Nezha configuration
+ >>> configuration = NezhaConfig()
+
+ >>> # Initializing a model from the Nezha-base style configuration
+ >>> model = NezhaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ pretrained_config_archive_map = NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP
+ model_type = "nezha"
+
+ def __init__(
+ self,
+ vocab_size=21128,
+ embedding_size=128,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ max_relative_position=64,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ classifier_dropout=0.1,
+ pad_token_id=0,
+ bos_token_id=2,
+ eos_token_id=3,
+ use_cache=True,
+ **kwargs
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.embedding_size = embedding_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.max_relative_position = max_relative_position
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.classifier_dropout = classifier_dropout
+ self.use_cache = use_cache
diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py
new file mode 100644
index 000000000000..4fa38b3ed48f
--- /dev/null
+++ b/src/transformers/models/nezha/modeling_nezha.py
@@ -0,0 +1,1727 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Nezha model."""
+
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import (
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ is_torch_greater_than_1_6,
+ prune_linear_layer,
+)
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_nezha import NezhaConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "sijunhe/nezha-cn-base"
+_CONFIG_FOR_DOC = "NezhaConfig"
+_TOKENIZER_FOR_DOC = "BertTokenizer"
+
+NEZHA_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "sijunhe/nezha-cn-base",
+ "sijunhe/nezha-cn-large",
+ "sijunhe/nezha-base-wwm",
+ "sijunhe/nezha-large-wwm",
+ # See all Nezha models at https://huggingface.co/models?filter=nezha
+]
+
+
+def load_tf_weights_in_nezha(model, config, tf_checkpoint_path):
+ """Load tf checkpoints in a pytorch model."""
+ try:
+ import re
+
+ import numpy as np
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(tf_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array)
+
+ for name, array in zip(names, arrays):
+ name = name.split("/")
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
+ # which are not required for using pretrained model
+ if any(
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+ for n in name
+ ):
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+ scope_names = re.split(r"_(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "output_weights":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "squad":
+ pointer = getattr(pointer, "classifier")
+ else:
+ try:
+ pointer = getattr(pointer, scope_names[0])
+ except AttributeError:
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+ if m_name[-11:] == "_embeddings":
+ pointer = getattr(pointer, "weight")
+ elif m_name == "kernel":
+ array = np.transpose(array)
+ try:
+ if pointer.shape != array.shape:
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+ except ValueError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ logger.info(f"Initialize PyTorch weight {name}")
+ pointer.data = torch.from_numpy(array)
+ return model
+
+
+class NezhaRelativePositionsEncoding(nn.Module):
+ """Implement the Functional Relative Position Encoding"""
+
+ def __init__(self, length, depth, max_relative_position=127):
+ super().__init__()
+ vocab_size = max_relative_position * 2 + 1
+ range_vec = torch.arange(length)
+ range_mat = range_vec.repeat(length).view(length, length)
+ distance_mat = range_mat - torch.t(range_mat)
+ distance_mat_clipped = torch.clamp(distance_mat, -max_relative_position, max_relative_position)
+ final_mat = distance_mat_clipped + max_relative_position
+
+ embeddings_table = torch.zeros(vocab_size, depth)
+ position = torch.arange(0, vocab_size, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, depth, 2).float() * (-math.log(10000.0) / depth))
+ embeddings_table[:, 0::2] = torch.sin(position * div_term)
+ embeddings_table[:, 1::2] = torch.cos(position * div_term)
+
+ flat_relative_positions_matrix = final_mat.view(-1)
+ one_hot_relative_positions_matrix = torch.nn.functional.one_hot(
+ flat_relative_positions_matrix, num_classes=vocab_size
+ ).float()
+ positions_encoding = torch.matmul(one_hot_relative_positions_matrix, embeddings_table)
+ my_shape = list(final_mat.size())
+ my_shape.append(depth)
+ positions_encoding = positions_encoding.view(my_shape)
+ self.register_buffer("positions_encoding", positions_encoding)
+
+ def forward(self, length):
+ return self.positions_encoding[:length, :length, :]
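To make the construction above concrete: the module precomputes, for every (query position, key position) pair, a sinusoidal embedding of the clipped relative distance, and `forward(length)` just slices that table. A small sketch, assuming the class defined above is in scope (sizes are illustrative):

```python
# Sketch: entry [i, j] of the returned tensor encodes clamp(j - i, -64, 64)
# as a `depth`-dimensional sinusoid; sizes here are illustrative.
encoding = NezhaRelativePositionsEncoding(length=512, depth=64, max_relative_position=64)
rel = encoding(10)
print(rel.shape)  # torch.Size([10, 10, 64])
```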
+
+
+class NezhaEmbeddings(nn.Module):
+ """Construct the embeddings from word and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ if is_torch_greater_than_1_6:
+ self.register_buffer(
+ "token_type_ids",
+ torch.zeros((1, config.max_position_embeddings), dtype=torch.long),
+ persistent=False,
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ # Set token_type_ids to the buffer registered in the constructor, where it is all zeros. This usually happens
+ # when token_type_ids are auto-generated; the registered buffer helps users trace the model without passing
+ # token_type_ids and solves issue #5664
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device)
+
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class NezhaSelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.relative_positions_encoding = NezhaRelativePositionsEncoding(
+ length=config.max_position_embeddings,
+ depth=self.attention_head_size,
+ max_relative_position=config.max_relative_position,
+ )
+ self.is_decoder = config.is_decoder
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_layer = past_key_value[0]
+ value_layer = past_key_value[1]
+ attention_mask = encoder_attention_mask
+ elif is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ batch_size, num_attention_heads, from_seq_length, to_seq_length = attention_scores.size()
+ relations_keys = self.relative_positions_encoding(to_seq_length)
+ query_layer_t = query_layer.permute(2, 0, 1, 3)
+
+ query_layer_r = query_layer_t.contiguous().view(
+ from_seq_length, batch_size * num_attention_heads, self.attention_head_size
+ )
+ key_position_scores = torch.matmul(query_layer_r, relations_keys.permute(0, 2, 1))
+ key_position_scores_r = key_position_scores.view(
+ from_seq_length, batch_size, num_attention_heads, from_seq_length
+ )
+ key_position_scores_r_t = key_position_scores_r.permute(1, 2, 0, 3)
+ attention_scores = attention_scores + key_position_scores_r_t
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in the NezhaModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ relations_values = self.relative_positions_encoding(to_seq_length)
+ attention_probs_t = attention_probs.permute(2, 0, 1, 3)
+ attentions_probs_r = attention_probs_t.contiguous().view(
+ from_seq_length, batch_size * num_attention_heads, to_seq_length
+ )
+ value_position_scores = torch.matmul(attentions_probs_r, relations_values)
+ value_position_scores_r = value_position_scores.view(
+ from_seq_length, batch_size, num_attention_heads, self.attention_head_size
+ )
+ value_position_scores_r_t = value_position_scores_r.permute(1, 2, 0, 3)
+ context_layer = context_layer + value_position_scores_r_t
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ if self.is_decoder:
+ outputs = outputs + (past_key_value,)
+ return outputs
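The comments in `forward` above describe how the relative-position terms are added both to the attention logits (via `key_position_scores_r_t`) and to the context (via `value_position_scores_r_t`). A hedged shape walkthrough with small, illustrative dimensions, assuming `NezhaSelfAttention` and `NezhaConfig` from this PR are in scope:

```python
# Illustrative shape walkthrough of the relative-position-aware attention above.
import torch

config = NezhaConfig(
    hidden_size=16, num_attention_heads=2, max_position_embeddings=8, max_relative_position=4
)
attn = NezhaSelfAttention(config)

hidden_states = torch.randn(1, 5, 16)                     # (batch, seq_len, hidden)
context, probs = attn(hidden_states, output_attentions=True)
print(context.shape)                                      # torch.Size([1, 5, 16])
print(probs.shape)                                        # torch.Size([1, 2, 5, 5]), one map per head
```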
+
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Nezha
+class NezhaSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class NezhaAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.self = NezhaSelfAttention(config)
+ self.output = NezhaSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Nezha
+class NezhaIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Nezha
+class NezhaOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class NezhaLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = NezhaAttention(config)
+ self.is_decoder = config.is_decoder
+ self.add_cross_attention = config.add_cross_attention
+ if self.add_cross_attention:
+ if not self.is_decoder:
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+ self.crossattention = NezhaAttention(config)
+ self.intermediate = NezhaIntermediate(config)
+ self.output = NezhaOutput(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+
+ # if decoder, the last output is tuple of self-attn cache
+ if self.is_decoder:
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+ else:
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ cross_attn_present_key_value = None
+ if self.is_decoder and encoder_hidden_states is not None:
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
+ )
+
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ cross_attn_past_key_value,
+ output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
+ cross_attn_present_key_value = cross_attention_outputs[-1]
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ # if decoder, return the attn key/values as the last output
+ if self.is_decoder:
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Nezha
+class NezhaEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([NezhaLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->Nezha
+class NezhaPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->Nezha
+class NezhaPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Nezha
+class NezhaLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = NezhaPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Nezha
+class NezhaOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = NezhaLMPredictionHead(config)
+
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert->Nezha
+class NezhaOnlyNSPHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, pooled_output):
+ seq_relationship_score = self.seq_relationship(pooled_output)
+ return seq_relationship_score
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->Nezha
+class NezhaPreTrainingHeads(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = NezhaLMPredictionHead(config)
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, sequence_output, pooled_output):
+ prediction_scores = self.predictions(sequence_output)
+ seq_relationship_score = self.seq_relationship(pooled_output)
+ return prediction_scores, seq_relationship_score
+
+
+class NezhaPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = NezhaConfig
+ load_tf_weights = load_tf_weights_in_nezha
+ base_model_prefix = "nezha"
+ supports_gradient_checkpointing = True
+ _keys_to_ignore_on_load_missing = [r"positions_encoding"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, NezhaEncoder):
+ module.gradient_checkpointing = value
+
+
+@dataclass
+class NezhaForPreTrainingOutput(ModelOutput):
+ """
+ Output type of [`NezhaForPreTraining`].
+
+ Args:
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
+ (classification) loss.
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
+ before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ prediction_logits: torch.FloatTensor = None
+ seq_relationship_logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+NEZHA_START_DOCSTRING = r"""
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`NezhaConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+NEZHA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Nezha Model transformer outputting raw hidden-states without any specific head on top.",
+ NEZHA_START_DOCSTRING,
+)
+class NezhaModel(NezhaPreTrainedModel):
+ """
+
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+ To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+ to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = NezhaEmbeddings(config)
+ self.encoder = NezhaEncoder(config)
+
+ self.pooler = NezhaPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+ r"""
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.config.is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ )
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Nezha Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
+ sentence prediction (classification)` head.
+ """,
+ NEZHA_START_DOCSTRING,
+)
+class NezhaForPreTraining(NezhaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.nezha = NezhaModel(config)
+ self.cls = NezhaPreTrainingHeads(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=NezhaForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ next_sentence_label: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], NezhaForPreTrainingOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
+ the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+ next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
+ pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
+
+ - 0 indicates sequence B is a continuation of sequence A,
+ - 1 indicates sequence B is a random sequence.
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
+ Used to hide legacy arguments that have been deprecated.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import BertTokenizer, NezhaForPreTraining
+ >>> import torch
+
+ >>> tokenizer = BertTokenizer.from_pretrained("sijunhe/nezha-cn-base")
+ >>> model = NezhaForPreTraining.from_pretrained("sijunhe/nezha-cn-base")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_logits = outputs.prediction_logits
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
+ ```
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.nezha(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output, pooled_output = outputs[:2]
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+
+ total_loss = None
+ if labels is not None and next_sentence_label is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+ total_loss = masked_lm_loss + next_sentence_loss
+
+ if not return_dict:
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return NezhaForPreTrainingOutput(
+ loss=total_loss,
+ prediction_logits=prediction_scores,
+ seq_relationship_logits=seq_relationship_score,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings("""Nezha Model with a `language modeling` head on top.""", NEZHA_START_DOCSTRING)
+class NezhaForMaskedLM(NezhaPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"predictions.decoder.bias", r"positions_encoding"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ if config.is_decoder:
+ logger.warning(
+ "If you want to use `NezhaForMaskedLM` make sure `config.is_decoder=False` for "
+ "bi-directional self-attention."
+ )
+
+ self.nezha = NezhaModel(config, add_pooling_layer=False)
+ self.cls = NezhaOnlyMLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=MaskedLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.nezha(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ effective_batch_size = input_shape[0]
+
+ # add a dummy token
+ if self.config.pad_token_id is None:
+ raise ValueError("The PAD token should be defined for generation")
+
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
+ dummy_token = torch.full(
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
+ )
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
+@add_start_docstrings(
+ """Nezha Model with a `next sentence prediction (classification)` head on top.""",
+ NEZHA_START_DOCSTRING,
+)
+class NezhaForNextSentencePrediction(NezhaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.nezha = NezhaModel(config)
+ self.cls = NezhaOnlyNSPHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
+ (see `input_ids` docstring). Indices should be in `[0, 1]`:
+
+ - 0 indicates sequence B is a continuation of sequence A,
+ - 1 indicates sequence B is a random sequence.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import BertTokenizer, NezhaForNextSentencePrediction
+ >>> import torch
+
+ >>> tokenizer = BertTokenizer.from_pretrained("sijunhe/nezha-cn-base")
+ >>> model = NezhaForNextSentencePrediction.from_pretrained("sijunhe/nezha-cn-base")
+
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
+
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
+ >>> logits = outputs.logits
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
+ ```
+ """
+
+ if "next_sentence_label" in kwargs:
+ warnings.warn(
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
+ " `labels` instead.",
+ FutureWarning,
+ )
+ labels = kwargs.pop("next_sentence_label")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.nezha(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ seq_relationship_scores = self.cls(pooled_output)
+
+ next_sentence_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
+
+ if not return_dict:
+ output = (seq_relationship_scores,) + outputs[2:]
+ return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
+
+ return NextSentencePredictorOutput(
+ loss=next_sentence_loss,
+ logits=seq_relationship_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Nezha Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+ output) e.g. for GLUE tasks.
+ """,
+ NEZHA_START_DOCSTRING,
+)
+class NezhaForSequenceClassification(NezhaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.nezha = NezhaModel(config)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.nezha(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+        if labels is not None:
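+            # `problem_type` is inferred when not set in the config: a single label means regression,
+            # integer labels mean single-label classification, anything else is treated as
+            # multi-label classification.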
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Nezha Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+ softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ NEZHA_START_DOCSTRING,
+)
+class NezhaForMultipleChoice(NezhaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.nezha = NezhaModel(config)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=MultipleChoiceModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.nezha(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
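+        # `logits` has shape (batch_size * num_choices, 1); fold the choices back into the batch dimension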
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Nezha Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ NEZHA_START_DOCSTRING,
+)
+class NezhaForTokenClassification(NezhaPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.nezha = NezhaModel(config, add_pooling_layer=False)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.nezha(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Nezha Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ NEZHA_START_DOCSTRING,
+)
+class NezhaForQuestionAnswering(NezhaPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.nezha = NezhaModel(config, add_pooling_layer=False)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=QuestionAnsweringModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.Tensor] = None,
+ end_positions: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.nezha(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
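+        # `qa_outputs` maps every token to two scores, split below into span-start and span-end logits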
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split adds a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/src/transformers/models/nllb/__init__.py b/src/transformers/models/nllb/__init__.py
new file mode 100644
index 000000000000..a678bf527440
--- /dev/null
+++ b/src/transformers/models/nllb/__init__.py
@@ -0,0 +1,68 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_sentencepiece_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
+
+
+_import_structure = {}
+
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_nllb"] = ["NllbTokenizer"]
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_nllb_fast"] = ["NllbTokenizerFast"]
+
+
+if TYPE_CHECKING:
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_nllb import NllbTokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_nllb_fast import NllbTokenizerFast
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/nllb/tokenization_nllb.py b/src/transformers/models/nllb/tokenization_nllb.py
new file mode 100644
index 000000000000..6a326fd3ca10
--- /dev/null
+++ b/src/transformers/models/nllb/tokenization_nllb.py
@@ -0,0 +1,401 @@
+# coding=utf-8
+# Copyright 2022 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+
+from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "facebook/nllb-200-distilled-600M": (
+ "https://huggingface.co/facebook/nllb-200-distilled-600M/blob/main/sentencepiece.bpe.model"
+ ),
+ }
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "facebook/nllb-200-distilled-600M": 1024,
+}
+
+# fmt: off
+FAIRSEQ_LANGUAGE_CODES = ['ace_Arab', 'ace_Latn', 'acm_Arab', 'acq_Arab', 'aeb_Arab', 'afr_Latn', 'ajp_Arab', 'aka_Latn', 'amh_Ethi', 'apc_Arab', 'arb_Arab', 'ars_Arab', 'ary_Arab', 'arz_Arab', 'asm_Beng', 'ast_Latn', 'awa_Deva', 'ayr_Latn', 'azb_Arab', 'azj_Latn', 'bak_Cyrl', 'bam_Latn', 'ban_Latn', 'bel_Cyrl', 'bem_Latn', 'ben_Beng', 'bho_Deva', 'bjn_Arab', 'bjn_Latn', 'bod_Tibt', 'bos_Latn', 'bug_Latn', 'bul_Cyrl', 'cat_Latn', 'ceb_Latn', 'ces_Latn', 'cjk_Latn', 'ckb_Arab', 'crh_Latn', 'cym_Latn', 'dan_Latn', 'deu_Latn', 'dik_Latn', 'dyu_Latn', 'dzo_Tibt', 'ell_Grek', 'eng_Latn', 'epo_Latn', 'est_Latn', 'eus_Latn', 'ewe_Latn', 'fao_Latn', 'pes_Arab', 'fij_Latn', 'fin_Latn', 'fon_Latn', 'fra_Latn', 'fur_Latn', 'fuv_Latn', 'gla_Latn', 'gle_Latn', 'glg_Latn', 'grn_Latn', 'guj_Gujr', 'hat_Latn', 'hau_Latn', 'heb_Hebr', 'hin_Deva', 'hne_Deva', 'hrv_Latn', 'hun_Latn', 'hye_Armn', 'ibo_Latn', 'ilo_Latn', 'ind_Latn', 'isl_Latn', 'ita_Latn', 'jav_Latn', 'jpn_Jpan', 'kab_Latn', 'kac_Latn', 'kam_Latn', 'kan_Knda', 'kas_Arab', 'kas_Deva', 'kat_Geor', 'knc_Arab', 'knc_Latn', 'kaz_Cyrl', 'kbp_Latn', 'kea_Latn', 'khm_Khmr', 'kik_Latn', 'kin_Latn', 'kir_Cyrl', 'kmb_Latn', 'kon_Latn', 'kor_Hang', 'kmr_Latn', 'lao_Laoo', 'lvs_Latn', 'lij_Latn', 'lim_Latn', 'lin_Latn', 'lit_Latn', 'lmo_Latn', 'ltg_Latn', 'ltz_Latn', 'lua_Latn', 'lug_Latn', 'luo_Latn', 'lus_Latn', 'mag_Deva', 'mai_Deva', 'mal_Mlym', 'mar_Deva', 'min_Latn', 'mkd_Cyrl', 'plt_Latn', 'mlt_Latn', 'mni_Beng', 'khk_Cyrl', 'mos_Latn', 'mri_Latn', 'zsm_Latn', 'mya_Mymr', 'nld_Latn', 'nno_Latn', 'nob_Latn', 'npi_Deva', 'nso_Latn', 'nus_Latn', 'nya_Latn', 'oci_Latn', 'gaz_Latn', 'ory_Orya', 'pag_Latn', 'pan_Guru', 'pap_Latn', 'pol_Latn', 'por_Latn', 'prs_Arab', 'pbt_Arab', 'quy_Latn', 'ron_Latn', 'run_Latn', 'rus_Cyrl', 'sag_Latn', 'san_Deva', 'sat_Beng', 'scn_Latn', 'shn_Mymr', 'sin_Sinh', 'slk_Latn', 'slv_Latn', 'smo_Latn', 'sna_Latn', 'snd_Arab', 'som_Latn', 'sot_Latn', 'spa_Latn', 'als_Latn', 'srd_Latn', 'srp_Cyrl', 'ssw_Latn', 'sun_Latn', 'swe_Latn', 'swh_Latn', 'szl_Latn', 'tam_Taml', 'tat_Cyrl', 'tel_Telu', 'tgk_Cyrl', 'tgl_Latn', 'tha_Thai', 'tir_Ethi', 'taq_Latn', 'taq_Tfng', 'tpi_Latn', 'tsn_Latn', 'tso_Latn', 'tuk_Latn', 'tum_Latn', 'tur_Latn', 'twi_Latn', 'tzm_Tfng', 'uig_Arab', 'ukr_Cyrl', 'umb_Latn', 'urd_Arab', 'uzn_Latn', 'vec_Latn', 'vie_Latn', 'war_Latn', 'wol_Latn', 'xho_Latn', 'ydd_Hebr', 'yor_Latn', 'yue_Hant', 'zho_Hans', 'zho_Hant', 'zul_Latn']
+# fmt: on
+
+
+class NllbTokenizer(PreTrainedTokenizer):
+ """
+ Construct an NLLB tokenizer.
+
+ Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
+ [SentencePiece](https://github.com/google/sentencepiece).
+
+    The tokenization method is `<tokens> <eos> <src_lang_code>` for source language documents, and
+    `<tokens> <eos> <tgt_lang_code>` for target language documents.
+
+ Examples:
+
+ ```python
+ >>> from transformers import NllbTokenizer
+
+ >>> tokenizer = NllbTokenizer.from_pretrained(
+ ... "facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="fra_Latn"
+ ... )
+ >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
+ >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
+ >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
+ ```
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+            for sequence classification or for a text and a question for question answering. It is also used as the
+            last token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole
+            sequence instead of per-token classification). It is the first token of the sequence when built with
+            special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ tokenizer_file (`str`, *optional*):
+ The path to a tokenizer file to use instead of the vocab file.
+ src_lang (`str`, *optional*):
+ The language to use as source language for translation.
+ tgt_lang (`str`, *optional*):
+ The language to use as target language for translation.
+        sp_model_kwargs (`Dict[str, Any]`, *optional*):
+            Additional keyword arguments passed to the underlying `SentencePieceProcessor` when it is initialized.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ model_input_names = ["input_ids", "attention_mask"]
+
+ prefix_tokens: List[int] = []
+ suffix_tokens: List[int] = []
+
+ def __init__(
+ self,
+ vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+ tokenizer_file=None,
+ src_lang=None,
+ tgt_lang=None,
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
+ additional_special_tokens=None,
+ **kwargs
+ ):
+
+ # Mask token behave like a normal word, i.e. include the space before it
+ mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ tokenizer_file=tokenizer_file,
+ src_lang=src_lang,
+ tgt_lang=tgt_lang,
+ additional_special_tokens=additional_special_tokens,
+ sp_model_kwargs=self.sp_model_kwargs,
+ **kwargs,
+ )
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(str(vocab_file))
+ self.vocab_file = vocab_file
+
+ # Original fairseq vocab and spm vocab must be "aligned":
+        # Vocab    |    0    |    1    |   2    |    3    |  4   |  5   |  6   |   7  |   8  |  9
+        # -------- | ------- | ------- | ------ | ------- | ---- | ---- | ---- | ---- | ---- | ----
+        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | 'an' | '▁n' | '▁m' | '▁t' | '▁k' | '▁a'
+        # spm      | '<unk>' | '<s>'   | '</s>' | 'an'    | '▁n' | '▁m' | '▁t' | '▁k' | '▁a' | '▁s'
+
+        # Mimic fairseq token-to-id alignment for the first 4 tokens
+        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+        # The first "real" token 'an' has position 4 in the original fairseq vocab and position 3 in the spm vocab
+ self.fairseq_offset = 1
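+        # A sentencepiece id `i` therefore maps to fairseq id `i + self.fairseq_offset` (e.g. 'an' is 3 in the spm
+        # vocab and 4 in the fairseq vocab), while the four special tokens above keep their fairseq positions 0-3.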
+
+ self.sp_model_size = len(self.sp_model)
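+        # Language codes are assigned ids directly after the sentencepiece vocabulary; `<mask>` takes the very last
+        # id (which is also why `vocab_size` adds 1 for the mask token).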
+ self.lang_code_to_id = {
+ code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
+ }
+ self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
+        self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
+
+ self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
+ self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
+ self._additional_special_tokens = list(self.lang_code_to_id.keys())
+
+ if additional_special_tokens is not None:
+ # Only add those special tokens if they are not already there.
+ self._additional_special_tokens.extend(
+ [t for t in additional_special_tokens if t not in self._additional_special_tokens]
+ )
+
+ self._src_lang = src_lang if src_lang is not None else "eng_Latn"
+ self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
+ self.tgt_lang = tgt_lang
+ self.set_src_lang_special_tokens(self._src_lang)
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
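+        # The SentencePiece processor itself is not picklable, so drop it and keep its serialized proto;
+        # `__setstate__` rebuilds the processor from that proto.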
+ state["sp_model"] = None
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+
+ # for backward compatibility
+ if not hasattr(self, "sp_model_kwargs"):
+ self.sp_model_kwargs = {}
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+ @property
+ def vocab_size(self):
+ return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1 # Plus 1 for the mask token
+
+ @property
+ def src_lang(self) -> str:
+ return self._src_lang
+
+ @src_lang.setter
+ def src_lang(self, new_src_lang: str) -> None:
+ self._src_lang = new_src_lang
+ self.set_src_lang_special_tokens(self._src_lang)
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ prefix_ones = [1] * len(self.prefix_tokens)
+ suffix_ones = [1] * len(self.suffix_tokens)
+ if token_ids_1 is None:
+ return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
+ return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. An NLLB sequence has the following format, where `X` represents the sequence:
+
+ - `input_ids` (for encoder) `X [eos, src_lang_code]`
+ - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
+
+ BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+ separator.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return self.prefix_tokens + token_ids_0 + self.suffix_tokens
+ # We don't expect to process pairs, but leave the pair logic for API consistency
+ return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. NLLB does not
+ make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of zeros.
+
+ """
+
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ def _build_translation_inputs(
+ self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
+ ):
+ """Used by translation pipeline, to prepare inputs for the generate function"""
+ if src_lang is None or tgt_lang is None:
+ raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
+ self.src_lang = src_lang
+ inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
+ tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
+ inputs["forced_bos_token_id"] = tgt_lang_id
+ return inputs
+
+ def get_vocab(self):
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text: str) -> List[str]:
+ return self.sp_model.encode(text, out_type=str)
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ if token in self.fairseq_tokens_to_ids:
+ return self.fairseq_tokens_to_ids[token]
+ spm_id = self.sp_model.PieceToId(token)
+
+ # Need to return unknown token if the SP model returned 0
+ return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ if index in self.fairseq_ids_to_tokens:
+ return self.fairseq_ids_to_tokens[index]
+ return self.sp_model.IdToPiece(index - self.fairseq_offset)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (strings for sub-words) in a single string."""
+ out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
+ return out_string
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+ def prepare_seq2seq_batch(
+ self,
+ src_texts: List[str],
+ src_lang: str = "eng_Latn",
+ tgt_texts: Optional[List[str]] = None,
+ tgt_lang: str = "fra_Latn",
+ **kwargs,
+ ) -> BatchEncoding:
+ self.src_lang = src_lang
+ self.tgt_lang = tgt_lang
+ return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
+
+ def _switch_to_input_mode(self):
+ return self.set_src_lang_special_tokens(self.src_lang)
+
+ def _switch_to_target_mode(self):
+ return self.set_tgt_lang_special_tokens(self.tgt_lang)
+
+ def set_src_lang_special_tokens(self, src_lang) -> None:
+ """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
+ self.cur_lang_code = self.lang_code_to_id[src_lang]
+ self.prefix_tokens = []
+ self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+
+ def set_tgt_lang_special_tokens(self, lang: str) -> None:
+ """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
+ self.cur_lang_code = self.lang_code_to_id[lang]
+ self.prefix_tokens = []
+ self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
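
Editor's note: taken together, `build_inputs_with_special_tokens` and the two `set_*_lang_special_tokens` helpers above mean that every encoded sequence has no prefix token and ends with `[eos, lang_code]`. A minimal sketch of the expected behaviour (it assumes `sentencepiece` is installed and that the `facebook/nllb-200-distilled-600M` checkpoint referenced elsewhere in this diff is reachable):

```python
from transformers import NllbTokenizer

tokenizer = NllbTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="fra_Latn"
)

ids = tokenizer("UN Chief says there is no military solution in Syria").input_ids
# No prefix token; the sequence ends with </s> followed by the source language code.
assert ids[-2] == tokenizer.eos_token_id
assert ids[-1] == tokenizer.convert_tokens_to_ids("eng_Latn")
```
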
diff --git a/src/transformers/models/nllb/tokenization_nllb_fast.py b/src/transformers/models/nllb/tokenization_nllb_fast.py
new file mode 100644
index 000000000000..1afe27f43b4e
--- /dev/null
+++ b/src/transformers/models/nllb/tokenization_nllb_fast.py
@@ -0,0 +1,336 @@
+# coding=utf-8
+# Copyright 2022 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from shutil import copyfile
+from typing import List, Optional, Tuple
+
+from tokenizers import processors
+
+from ...tokenization_utils import AddedToken, BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+
+
+if is_sentencepiece_available():
+ from .tokenization_nllb import NllbTokenizer
+else:
+ NllbTokenizer = None
+
+
+logger = logging.get_logger(__name__)
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "facebook/nllb-200-distilled-600M": (
+ "https://huggingface.co/facebook/nllb-200-distilled-600M/resolve/main/sentencepiece.bpe.model"
+ ),
+ },
+ "tokenizer_file": {
+ "facebook/nllb-200-distilled-600M": (
+ "https://huggingface.co/facebook/nllb-200-distilled-600M/resolve/main/tokenizer.json"
+ ),
+ },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "facebook/nllb-large-en-ro": 1024,
+ "facebook/nllb-200-distilled-600M": 1024,
+}
+
+# fmt: off
+FAIRSEQ_LANGUAGE_CODES = ['ace_Arab', 'ace_Latn', 'acm_Arab', 'acq_Arab', 'aeb_Arab', 'afr_Latn', 'ajp_Arab', 'aka_Latn', 'amh_Ethi', 'apc_Arab', 'arb_Arab', 'ars_Arab', 'ary_Arab', 'arz_Arab', 'asm_Beng', 'ast_Latn', 'awa_Deva', 'ayr_Latn', 'azb_Arab', 'azj_Latn', 'bak_Cyrl', 'bam_Latn', 'ban_Latn', 'bel_Cyrl', 'bem_Latn', 'ben_Beng', 'bho_Deva', 'bjn_Arab', 'bjn_Latn', 'bod_Tibt', 'bos_Latn', 'bug_Latn', 'bul_Cyrl', 'cat_Latn', 'ceb_Latn', 'ces_Latn', 'cjk_Latn', 'ckb_Arab', 'crh_Latn', 'cym_Latn', 'dan_Latn', 'deu_Latn', 'dik_Latn', 'dyu_Latn', 'dzo_Tibt', 'ell_Grek', 'eng_Latn', 'epo_Latn', 'est_Latn', 'eus_Latn', 'ewe_Latn', 'fao_Latn', 'pes_Arab', 'fij_Latn', 'fin_Latn', 'fon_Latn', 'fra_Latn', 'fur_Latn', 'fuv_Latn', 'gla_Latn', 'gle_Latn', 'glg_Latn', 'grn_Latn', 'guj_Gujr', 'hat_Latn', 'hau_Latn', 'heb_Hebr', 'hin_Deva', 'hne_Deva', 'hrv_Latn', 'hun_Latn', 'hye_Armn', 'ibo_Latn', 'ilo_Latn', 'ind_Latn', 'isl_Latn', 'ita_Latn', 'jav_Latn', 'jpn_Jpan', 'kab_Latn', 'kac_Latn', 'kam_Latn', 'kan_Knda', 'kas_Arab', 'kas_Deva', 'kat_Geor', 'knc_Arab', 'knc_Latn', 'kaz_Cyrl', 'kbp_Latn', 'kea_Latn', 'khm_Khmr', 'kik_Latn', 'kin_Latn', 'kir_Cyrl', 'kmb_Latn', 'kon_Latn', 'kor_Hang', 'kmr_Latn', 'lao_Laoo', 'lvs_Latn', 'lij_Latn', 'lim_Latn', 'lin_Latn', 'lit_Latn', 'lmo_Latn', 'ltg_Latn', 'ltz_Latn', 'lua_Latn', 'lug_Latn', 'luo_Latn', 'lus_Latn', 'mag_Deva', 'mai_Deva', 'mal_Mlym', 'mar_Deva', 'min_Latn', 'mkd_Cyrl', 'plt_Latn', 'mlt_Latn', 'mni_Beng', 'khk_Cyrl', 'mos_Latn', 'mri_Latn', 'zsm_Latn', 'mya_Mymr', 'nld_Latn', 'nno_Latn', 'nob_Latn', 'npi_Deva', 'nso_Latn', 'nus_Latn', 'nya_Latn', 'oci_Latn', 'gaz_Latn', 'ory_Orya', 'pag_Latn', 'pan_Guru', 'pap_Latn', 'pol_Latn', 'por_Latn', 'prs_Arab', 'pbt_Arab', 'quy_Latn', 'ron_Latn', 'run_Latn', 'rus_Cyrl', 'sag_Latn', 'san_Deva', 'sat_Beng', 'scn_Latn', 'shn_Mymr', 'sin_Sinh', 'slk_Latn', 'slv_Latn', 'smo_Latn', 'sna_Latn', 'snd_Arab', 'som_Latn', 'sot_Latn', 'spa_Latn', 'als_Latn', 'srd_Latn', 'srp_Cyrl', 'ssw_Latn', 'sun_Latn', 'swe_Latn', 'swh_Latn', 'szl_Latn', 'tam_Taml', 'tat_Cyrl', 'tel_Telu', 'tgk_Cyrl', 'tgl_Latn', 'tha_Thai', 'tir_Ethi', 'taq_Latn', 'taq_Tfng', 'tpi_Latn', 'tsn_Latn', 'tso_Latn', 'tuk_Latn', 'tum_Latn', 'tur_Latn', 'twi_Latn', 'tzm_Tfng', 'uig_Arab', 'ukr_Cyrl', 'umb_Latn', 'urd_Arab', 'uzn_Latn', 'vec_Latn', 'vie_Latn', 'war_Latn', 'wol_Latn', 'xho_Latn', 'ydd_Hebr', 'yor_Latn', 'yue_Hant', 'zho_Hans', 'zho_Hant', 'zul_Latn']
+# fmt: on
+
+
+class NllbTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" NLLB tokenizer (backed by HuggingFace's *tokenizers* library). Based on
+ [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ The tokenization method is `<tokens> <eos> <src_lang_code>` for source language documents, and `<tokens> <eos>
+ <tgt_lang_code>` for target language documents.
+
+ Examples:
+
+ ```python
+ >>> from transformers import NllbTokenizerFast
+
+ >>> tokenizer = NllbTokenizerFast.from_pretrained(
+ ... "facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="fra_Latn"
+ ... )
+ >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
+ >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
+ >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
+ ```
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+ </Tip>
+
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ </Tip>
+
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ tokenizer_file (`str`, *optional*):
+ The path to a tokenizer file to use instead of the vocab file.
+ src_lang (`str`, *optional*):
+ The language to use as source language for translation.
+ tgt_lang (`str`, *optional*):
+ The language to use as target language for translation.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = NllbTokenizer
+
+ prefix_tokens: List[int] = []
+ suffix_tokens: List[int] = []
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ src_lang=None,
+ tgt_lang=None,
+ additional_special_tokens=None,
+ **kwargs
+ ):
+ # Mask token behave like a normal word, i.e. include the space before it
+ mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+ super().__init__(
+ vocab_file=vocab_file,
+ tokenizer_file=tokenizer_file,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ src_lang=src_lang,
+ tgt_lang=tgt_lang,
+ additional_special_tokens=additional_special_tokens,
+ **kwargs,
+ )
+
+ self.vocab_file = vocab_file
+ self.can_save_slow_tokenizer = bool(self.vocab_file)
+
+ _additional_special_tokens = FAIRSEQ_LANGUAGE_CODES.copy()
+
+ if additional_special_tokens is not None:
+ # Only add those special tokens if they are not already there.
+ _additional_special_tokens.extend(
+ [t for t in additional_special_tokens if t not in _additional_special_tokens]
+ )
+
+ self.add_special_tokens({"additional_special_tokens": _additional_special_tokens})
+ self.lang_code_to_id = {
+ lang_code: self.convert_tokens_to_ids(lang_code) for lang_code in FAIRSEQ_LANGUAGE_CODES
+ }
+
+ self._src_lang = src_lang if src_lang is not None else "eng_Latn"
+ self.cur_lang_code = self.convert_tokens_to_ids(self._src_lang)
+ self.tgt_lang = tgt_lang
+ self.set_src_lang_special_tokens(self._src_lang)
+
+ @property
+ def src_lang(self) -> str:
+ return self._src_lang
+
+ @src_lang.setter
+ def src_lang(self, new_src_lang: str) -> None:
+ self._src_lang = new_src_lang
+ self.set_src_lang_special_tokens(self._src_lang)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. The special tokens depend on the currently configured source/target language (see
+ `set_src_lang_special_tokens` and `set_tgt_lang_special_tokens`).
+
+ An NLLB sequence has the following format, where `X` represents the sequence:
+
+ - `input_ids` (for encoder) `X [eos, src_lang_code]`
+ - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
+
+ BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+ separator.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return self.prefix_tokens + token_ids_0 + self.suffix_tokens
+ # We don't expect to process pairs, but leave the pair logic for API consistency
+ return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. NLLB does not
+ make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of zeros.
+
+ """
+
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ def _build_translation_inputs(
+ self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
+ ):
+ """Used by translation pipeline, to prepare inputs for the generate function"""
+ if src_lang is None or tgt_lang is None:
+ raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
+ self.src_lang = src_lang
+ inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
+ tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
+ inputs["forced_bos_token_id"] = tgt_lang_id
+ return inputs
+
+ def prepare_seq2seq_batch(
+ self,
+ src_texts: List[str],
+ src_lang: str = "eng_Latn",
+ tgt_texts: Optional[List[str]] = None,
+ tgt_lang: str = "fra_Latn",
+ **kwargs,
+ ) -> BatchEncoding:
+ self.src_lang = src_lang
+ self.tgt_lang = tgt_lang
+ return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
+
+ def _switch_to_input_mode(self):
+ return self.set_src_lang_special_tokens(self.src_lang)
+
+ def _switch_to_target_mode(self):
+ return self.set_tgt_lang_special_tokens(self.tgt_lang)
+
+ def set_src_lang_special_tokens(self, src_lang) -> None:
+ """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
+ self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
+ self.prefix_tokens = []
+ self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+
+ prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
+ suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
+
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
+ pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
+ special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
+ )
+
+ def set_tgt_lang_special_tokens(self, lang: str) -> None:
+ """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
+ self.cur_lang_code = self.convert_tokens_to_ids(lang)
+ self.prefix_tokens = []
+ self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+
+ prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
+ suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
+
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
+ pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
+ special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
+ )
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not self.can_save_slow_tokenizer:
+ raise ValueError(
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+ "tokenizer."
+ )
+
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory.")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
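
Editor's note: the fast tokenizer implements the same suffix logic through a `TemplateProcessing` post-processor, so its output should match the slow tokenizer exactly. A hedged usage sketch of the translation path follows; the model class and `generate` arguments are assumptions based on how the library's other multilingual seq2seq checkpoints are used, not something this file defines:

```python
from transformers import AutoModelForSeq2SeqLM, NllbTokenizerFast

tokenizer = NllbTokenizerFast.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="eng_Latn"
)
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

inputs = tokenizer("UN Chief says there is no military solution in Syria", return_tensors="pt")
# Force the decoder to start in the target language, mirroring what
# `_build_translation_inputs` does for the translation pipeline.
generated = model.generate(
    **inputs, forced_bos_token_id=tokenizer.convert_tokens_to_ids("fra_Latn"), max_length=40
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```
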
diff --git a/src/transformers/models/nystromformer/__init__.py b/src/transformers/models/nystromformer/__init__.py
index d3df751dd4f6..a239e435f97b 100644
--- a/src/transformers/models/nystromformer/__init__.py
+++ b/src/transformers/models/nystromformer/__init__.py
@@ -18,14 +18,19 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
"configuration_nystromformer": ["NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "NystromformerConfig"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_nystromformer"] = [
"NYSTROMFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"NystromformerForMaskedLM",
@@ -42,7 +47,12 @@
if TYPE_CHECKING:
from .configuration_nystromformer import NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, NystromformerConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_nystromformer import (
NYSTROMFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
NystromformerForMaskedLM,
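
Editor's note: the same `OptionalDependencyNotAvailable` guard is applied to every conditional import touched by this diff. A stripped-down sketch of the pattern for a hypothetical submodule (`configuration_foo`, `FooConfig`, etc. are illustrative names, not part of this PR):

```python
from typing import TYPE_CHECKING

from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# Hypothetical `foo/__init__.py`. Objects that need no backend are always registered.
_import_structure = {"configuration_foo": ["FooConfig"]}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # torch is missing: the modeling symbols are simply not registered here; the
    # top-level package exposes dummy objects that raise an informative error instead.
    pass
else:
    _import_structure["modeling_foo"] = ["FooModel"]

if TYPE_CHECKING:
    from .configuration_foo import FooConfig  # noqa: F401
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
```
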
diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py
index b5813af781b7..e1f352d2c897 100755
--- a/src/transformers/models/nystromformer/modeling_nystromformer.py
+++ b/src/transformers/models/nystromformer/modeling_nystromformer.py
@@ -20,7 +20,6 @@
import torch
import torch.utils.checkpoint
-from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -34,7 +33,12 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ is_torch_greater_than_1_6,
+ prune_linear_layer,
+)
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_nystromformer import NystromformerConfig
@@ -68,7 +72,7 @@ def __init__(self, config):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
- if version.parse(torch.__version__) > version.parse("1.6.0"):
+ if is_torch_greater_than_1_6:
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
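
Editor's note: the inline `packaging` comparison is replaced by a shared flag from `...pytorch_utils`. I have not verified the exact definition in that module, but a centralised version gate of this kind is essentially a module-level one-liner along these lines:

```python
# Presumed shape of the shared flag imported above; the real definition lives in
# src/transformers/pytorch_utils.py and may differ slightly.
import torch
from packaging import version

parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_than_1_6 = parsed_torch_version_base > version.parse("1.6")
```
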
diff --git a/src/transformers/models/openai/__init__.py b/src/transformers/models/openai/__init__.py
index 3abba0b781bc..8aaaaa62a989 100644
--- a/src/transformers/models/openai/__init__.py
+++ b/src/transformers/models/openai/__init__.py
@@ -18,7 +18,13 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -26,10 +32,20 @@
"tokenization_openai": ["OpenAIGPTTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_openai_fast"] = ["OpenAIGPTTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_openai"] = [
"OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST",
"OpenAIGPTDoubleHeadsModel",
@@ -40,7 +56,12 @@
"load_tf_weights_in_openai_gpt",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_openai"] = [
"TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFOpenAIGPTDoubleHeadsModel",
@@ -56,10 +77,20 @@
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .tokenization_openai import OpenAIGPTTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_openai_fast import OpenAIGPTTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_openai import (
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
OpenAIGPTDoubleHeadsModel,
@@ -70,7 +101,12 @@
load_tf_weights_in_openai_gpt,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_openai import (
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFOpenAIGPTDoubleHeadsModel,
diff --git a/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py
index b57f2dd0339f..1b101aea0cc0 100755
--- a/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py
@@ -64,8 +64,10 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
"--openai_config_file",
default="",
type=str,
- help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
- "This specifies the model architecture.",
+ help=(
+ "An optional config json file corresponding to the pre-trained OpenAI model. \n"
+ "This specifies the model architecture."
+ ),
)
args = parser.parse_args()
convert_openai_checkpoint_to_pytorch(
diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py
index 2262db9aa8cf..e5e5da5da0c9 100644
--- a/src/transformers/models/openai/modeling_openai.py
+++ b/src/transformers/models/openai/modeling_openai.py
@@ -81,12 +81,14 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
# Check that the token and position embeddings weight dimensions map those of the init parameters.
if model.tokens_embed.weight.shape != init_params[1].shape:
raise ValueError(
- f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape: {init_params[1].shape}"
+ f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape:"
+ f" {init_params[1].shape}"
)
if model.positions_embed.weight.shape != init_params[0].shape:
raise ValueError(
- f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape: {init_params[0].shape}"
+ f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape:"
+ f" {init_params[0].shape}"
)
model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
@@ -477,7 +479,7 @@ def forward(
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
- attention_mask = (1.0 - attention_mask) * -10000.0
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
@@ -812,7 +814,7 @@ def forward(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[range(batch_size), sequence_lengths]
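
Editor's note: the switch from a hard-coded `-10000.0` to `torch.finfo(self.dtype).min` makes the additive attention bias saturate correctly in half precision instead of relying on a magic constant. A small self-contained illustration of the construction, independent of the model code:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # 1 = keep, 0 = padding
for dtype in (torch.float32, torch.float16):
    mask = attention_mask.to(dtype)
    additive = (1.0 - mask) * torch.finfo(dtype).min  # large negative bias on padded positions
    scores = torch.zeros(1, 4, dtype=dtype) + additive
    probs = scores.softmax(dim=-1)
    # The padded position receives (numerically) zero probability in both dtypes.
    print(dtype, probs)
```
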
diff --git a/src/transformers/models/openai/modeling_tf_openai.py b/src/transformers/models/openai/modeling_tf_openai.py
index 24a7935eb005..8a1761908628 100644
--- a/src/transformers/models/openai/modeling_tf_openai.py
+++ b/src/transformers/models/openai/modeling_tf_openai.py
@@ -239,17 +239,17 @@ def _prune_heads(self, heads_to_prune):
@unpack_inputs
def call(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- training=False,
- ):
+ input_ids: Optional[TFModelInputType] = None,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ ) -> Union[Tuple, TFBaseModelOutput]:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -556,6 +556,8 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
+ # OpenAIGPT does not have past caching features
+ self.supports_xla_generation = False
def get_output_embeddings(self):
return self.get_input_embeddings()
@@ -851,7 +853,7 @@ def call(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None
diff --git a/src/transformers/models/openai/tokenization_openai.py b/src/transformers/models/openai/tokenization_openai.py
index ca21943a2359..40bb824cd718 100644
--- a/src/transformers/models/openai/tokenization_openai.py
+++ b/src/transformers/models/openai/tokenization_openai.py
@@ -215,7 +215,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
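
Editor's note: writing the vocabulary with `indent=2`, `sort_keys=True` and a trailing newline makes saved `vocab.json` files deterministic and diff-friendly. The effect of the change, in isolation:

```python
import json

encoder = {"world</w>": 2, "hello</w>": 1}
# Old style: a single line in insertion order.
print(json.dumps(encoder, ensure_ascii=False))
# New style: sorted keys, indented, with a trailing newline.
print(json.dumps(encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
```
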
diff --git a/src/transformers/models/opt/__init__.py b/src/transformers/models/opt/__init__.py
new file mode 100644
index 000000000000..4e5508640972
--- /dev/null
+++ b/src/transformers/models/opt/__init__.py
@@ -0,0 +1,103 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
+
+
+_import_structure = {"configuration_opt": ["OPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "OPTConfig"]}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_opt"] = [
+ "OPT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "OPTForCausalLM",
+ "OPTModel",
+ "OPTPreTrainedModel",
+ "OPTForSequenceClassification",
+ ]
+
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_opt"] = ["TFOPTForCausalLM", "TFOPTModel", "TFOPTPreTrainedModel"]
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_opt"] = [
+ "FlaxOPTForCausalLM",
+ "FlaxOPTModel",
+ "FlaxOPTPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_opt import OPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OPTConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_opt import (
+ OPT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ OPTForCausalLM,
+ OPTForSequenceClassification,
+ OPTModel,
+ OPTPreTrainedModel,
+ )
+
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel
+
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
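
Editor's note: because the module registers its symbols through `_LazyModule`, nothing backend-specific is imported until a class is actually used, and the configuration stays importable with no deep-learning framework installed. A hedged sketch of guarding on the same availability helpers, assuming the OPT classes are also re-exported from the top-level `transformers` package elsewhere in this PR:

```python
from transformers import OPTConfig
from transformers.utils import is_flax_available, is_tf_available, is_torch_available

config = OPTConfig()  # the configuration never needs a backend

if is_torch_available():
    from transformers import OPTForCausalLM, OPTModel  # PyTorch classes
if is_tf_available():
    from transformers import TFOPTForCausalLM, TFOPTModel  # TensorFlow classes
if is_flax_available():
    from transformers import FlaxOPTForCausalLM, FlaxOPTModel  # Flax classes
```
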
diff --git a/src/transformers/models/opt/configuration_opt.py b/src/transformers/models/opt/configuration_opt.py
new file mode 100644
index 000000000000..a101bb3e866f
--- /dev/null
+++ b/src/transformers/models/opt/configuration_opt.py
@@ -0,0 +1,145 @@
+# coding=utf-8
+# Copyright 2022 The Metaseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" OPT model configuration"""
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+OPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "facebook/opt-125m": "https://huggingface.co/facebook/opt-125m/blob/main/config.json",
+ "facebook/opt-350m": "https://huggingface.co/facebook/opt-350m/blob/main/config.json",
+ "facebook/opt-1.3b": "https://huggingface.co/facebook/opt-1.3b/blob/main/config.json",
+ "facebook/opt-2.7b": "https://huggingface.co/facebook/opt-2.7b/blob/main/config.json",
+ "facebook/opt-6.7b": "https://huggingface.co/facebook/opt-6.7b/blob/main/config.json",
+ "facebook/opt-13b": "https://huggingface.co/facebook/opt-13b/blob/main/config.json",
+}
+
+
+class OPTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of an [`OPTModel`]. It is used to instantiate an OPT model
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the OPT
+ [facebook/opt-350m](https://huggingface.co/facebook/opt-350m) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50272):
+ Vocabulary size of the OPT model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`OPTModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the decoder layers and the hidden states.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of decoder layers.
+ ffn_dim (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
+ The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ do_layer_norm_before (`bool`, *optional*, defaults to `True`):
+ Whether to perform layer normalization before the attention block.
+ word_embed_proj_dim (`int`, *optional*):
+ `word_embed_proj_dim` can be set to down-project word embeddings, *e.g.* `opt-350m`. Defaults to
+ `hidden_size`.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings and decoder.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more
+ details.
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+
+ Example:
+
+ ```python
+ >>> from transformers import OPTModel, OPTConfig
+
+ >>> # Initializing an OPT facebook/opt-large style configuration
+ >>> configuration = OPTConfig()
+
+ >>> # Initializing a model from the facebook/opt-large style configuration
+ >>> model = OPTModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "opt"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=50272,
+ hidden_size=768,
+ num_hidden_layers=12,
+ ffn_dim=3072,
+ max_position_embeddings=2048,
+ do_layer_norm_before=True,
+ _remove_final_layer_norm=False,
+ word_embed_proj_dim=None,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ num_attention_heads=12,
+ activation_function="relu",
+ layerdrop=0.0,
+ init_std=0.02,
+ use_cache=True,
+ pad_token_id=1,
+ bos_token_id=2,
+ eos_token_id=2,
+ **kwargs
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.num_attention_heads = num_attention_heads
+ self.word_embed_proj_dim = word_embed_proj_dim if word_embed_proj_dim is not None else hidden_size
+ self.ffn_dim = ffn_dim
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.layerdrop = layerdrop
+ self.use_cache = use_cache
+ self.do_layer_norm_before = do_layer_norm_before
+
+ # Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility
+ # with checkpoints that have been fine-tuned before transformers v4.20.1
+ # see https://github.com/facebookresearch/metaseq/pull/164
+ self._remove_final_layer_norm = _remove_final_layer_norm
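
Editor's note: one detail worth illustrating is `word_embed_proj_dim`. When it differs from `hidden_size`, the word embeddings are projected into and out of the decoder width, which is how `facebook/opt-350m` is set up; the concrete numbers below are illustrative rather than copied from a checkpoint:

```python
from transformers import OPTConfig

# Default: the projection dim falls back to hidden_size, so no extra projection layers are created.
default_config = OPTConfig()
assert default_config.word_embed_proj_dim == default_config.hidden_size

# opt-350m-style setup (illustrative values): 512-dim word embeddings projected into a
# 1024-dim decoder, which is what triggers the `project_in`/`project_out` layers.
projected_config = OPTConfig(
    hidden_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,
    ffn_dim=4096,
    word_embed_proj_dim=512,
)
print(projected_config.hidden_size, projected_config.word_embed_proj_dim)
```
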
diff --git a/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py
new file mode 100644
index 000000000000..ec1749daeff7
--- /dev/null
+++ b/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py
@@ -0,0 +1,93 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert OPT checkpoint."""
+
+
+import argparse
+from pathlib import Path
+
+import torch
+
+from transformers import OPTConfig, OPTModel
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+def load_checkpoint(checkpoint_path):
+ """Checkpoint path should end in model.pt"""
+ sd = torch.load(checkpoint_path, map_location="cpu")
+ if "model" in sd.keys():
+ sd = torch.load(checkpoint_path, map_location="cpu")["model"]
+
+ # pop unnecessary weights
+ keys_to_delete = [
+ "decoder.version",
+ "decoder.output_projection.weight",
+ ]
+ for key in keys_to_delete:
+ if key in sd:
+ sd.pop(key)
+
+ keys_to_rename = {
+ "decoder.project_in_dim.weight": "decoder.project_in.weight",
+ "decoder.project_out_dim.weight": "decoder.project_out.weight",
+ "decoder.layer_norm.weight": "decoder.final_layer_norm.weight",
+ "decoder.layer_norm.bias": "decoder.final_layer_norm.bias",
+ }
+ for old_key, new_key in keys_to_rename.items():
+ if old_key in sd:
+ sd[new_key] = sd.pop(old_key)
+
+ return sd
+
+
+@torch.no_grad()
+def convert_opt_checkpoint(checkpoint_path, pytorch_dump_folder_path, config=None):
+ """
+ Copy/paste/tweak model's weights to our OPT structure.
+ """
+ state_dict = load_checkpoint(checkpoint_path)
+
+ if config is not None:
+ config = OPTConfig.from_pretrained(config)
+ else:
+ config = OPTConfig()
+
+ model = OPTModel(config).half().eval()
+ model.load_state_dict(state_dict)
+
+ # Save the converted model
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
+ model.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--fairseq_path",
+ type=str,
+ help=(
+ "path to fairseq checkpoint in correct format. You can find all checkpoints in the correct format here:"
+ " https://huggingface.co/models?other=opt_metasq"
+ ),
+ )
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+ parser.add_argument("--hf_config", default=None, type=str, help="Define HF config.")
+ args = parser.parse_args()
+ convert_opt_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, config=args.hf_config)
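
Editor's note: the script is driven entirely by the three arguments defined above. For reference, the same conversion can be run from Python; the paths and the optional config id below are placeholders, and the checkpoint must already be in the single-file `model.pt` format mentioned in the `--fairseq_path` help text:

```python
from transformers.models.opt.convert_opt_original_pytorch_checkpoint_to_pytorch import (
    convert_opt_checkpoint,
)

convert_opt_checkpoint(
    checkpoint_path="/path/to/opt-125m/model.pt",   # must end in model.pt
    pytorch_dump_folder_path="./opt-125m-hf",
    config="facebook/opt-125m",                     # optional: reuse an existing OPTConfig from the Hub
)
```
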
diff --git a/src/transformers/models/opt/modeling_flax_opt.py b/src/transformers/models/opt/modeling_flax_opt.py
new file mode 100644
index 000000000000..5762fae14b09
--- /dev/null
+++ b/src/transformers/models/opt/modeling_flax_opt.py
@@ -0,0 +1,806 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Flax OPT model."""
+
+from functools import partial
+from typing import Optional, Tuple
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+from jax.random import PRNGKey
+
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxMaskedLMOutput
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
+from ...utils import add_start_docstrings, logging
+from .configuration_opt import OPTConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
+_CONFIG_FOR_DOC = "OPTConfig"
+_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
+
+
+OPT_START_DOCSTRING = r"""
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a Flax Linen
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+ regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ config ([`OPTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+OPT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->OPT
+class FlaxOPTAttention(nn.Module):
+ config: OPTConfig
+ embed_dim: int
+ num_heads: int
+ dropout: float = 0.0
+ causal: bool = False
+ bias: bool = True
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self) -> None:
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ dense = partial(
+ nn.Dense,
+ self.embed_dim,
+ use_bias=self.bias,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
+ self.out_proj = dense()
+
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+ if self.causal:
+ self.causal_mask = make_causal_mask(
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
+ )
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
+
+ @nn.compact
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slightly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
+ def __call__(
+ self,
+ hidden_states: jnp.ndarray,
+ key_value_states: Optional[jnp.ndarray] = None,
+ attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ deterministic: bool = True,
+ ) -> Tuple[jnp.ndarray]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ batch_size = hidden_states.shape[0]
+
+ # get query proj
+ query_states = self.q_proj(hidden_states)
+ # get key, value proj
+ if is_cross_attention:
+ # cross_attentions
+ key_states = self.k_proj(key_value_states)
+ value_states = self.v_proj(key_value_states)
+ else:
+ # self_attention
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # handle cache prepare causal attention mask
+ if self.causal:
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
+ if self.has_variable("cache", "cached_key"):
+ mask_shift = self.variables["cache"]["cache_index"]
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_mask = lax.dynamic_slice(
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+ )
+ else:
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+ # combine masks if needed
+ if attention_mask is not None and self.causal:
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+ attention_mask = combine_masks(attention_mask, causal_mask)
+ elif self.causal:
+ attention_mask = causal_mask
+ elif attention_mask is not None:
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
+ key_states, value_states, query_states, attention_mask
+ )
+
+ # Convert the boolean attention mask to an attention bias.
+ if attention_mask is not None:
+ # attention mask in the form of attention bias
+ attention_bias = lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+ )
+ else:
+ attention_bias = None
+
+ dropout_rng = None
+ if not deterministic and self.dropout > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ attn_weights = dot_product_attention_weights(
+ query_states,
+ key_states,
+ bias=attention_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.dropout,
+ broadcast_dropout=True,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ precision=None,
+ )
+
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+ attn_output = self._merge_heads(attn_output)
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class FlaxOPTDecoderLayer(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self) -> None:
+ self.embed_dim = self.config.hidden_size
+ self.self_attn = FlaxOPTAttention(
+ config=self.config,
+ embed_dim=self.embed_dim,
+ num_heads=self.config.num_attention_heads,
+ dropout=self.config.attention_dropout,
+ causal=True,
+ dtype=self.dtype,
+ )
+ self.do_layer_norm_before = self.config.do_layer_norm_before
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+ self.activation_fn = ACT2FN[self.config.activation_function]
+
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+ self.fc1 = nn.Dense(
+ self.config.ffn_dim,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+ self.fc2 = nn.Dense(
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
+ )
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+ def __call__(
+ self,
+ hidden_states: jnp.ndarray,
+ attention_mask: jnp.ndarray,
+ init_cache: bool = False,
+ output_attentions: bool = True,
+ deterministic: bool = True,
+ ) -> Tuple[jnp.ndarray]:
+
+ residual = hidden_states
+
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ init_cache=init_cache,
+ deterministic=deterministic,
+ )
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = residual + hidden_states
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ hidden_states_shape = hidden_states.shape
+ hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
+ residual = hidden_states
+
+ # 125m, 1.3B, ..., 175B applies layer norm BEFORE the fully connected block
+ if self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+
+ hidden_states = (residual + hidden_states).reshape(hidden_states_shape)
+
+ # 350m applies layer norm AFTER the fully connected block
+ if not self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
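
The `do_layer_norm_before` flag toggles between the two residual orderings referenced in the comments above. A hedged, stripped-down sketch of the two patterns (hypothetical `norm` and `sublayer` callables, not the Flax modules themselves):

```python
def pre_ln_block(x, norm, sublayer):
    # 125m, 1.3B, ..., 175B: normalize first, then apply the sub-layer and add the residual
    return x + sublayer(norm(x))


def post_ln_block(x, norm, sublayer):
    # 350m: apply the sub-layer, add the residual, then normalize the sum
    return norm(x + sublayer(x))
```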
+
+class FlaxOPTDecoderLayerCollection(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.layers = [
+ FlaxOPTDecoderLayer(self.config, name=str(i), dtype=self.dtype)
+ for i in range(self.config.num_hidden_layers)
+ ]
+ self.layerdrop = self.config.layerdrop
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ ):
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ )
+
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ outputs = [hidden_states, all_hidden_states, all_self_attns]
+ return outputs
+
+
+class FlaxOPTLearnedPositionalEmbedding(nn.Embed):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def setup(self):
+ self.offset = 2
+ self.embedding = self.param(
+ "embedding", self.embedding_init, (self.num_embeddings + self.offset, self.features), self.param_dtype
+ )
+
+ def __call__(self, positions):
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+
+ return super().__call__(positions + self.offset)
+
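
Because OPT checkpoints reserve the first `offset = 2` rows of the position table, position `p` is looked up at row `p + 2`. A small sketch with a stand-in table (plain JAX, not the module above; not part of the patch):

```python
import jax.numpy as jnp

offset = 2
max_positions, features = 6, 4
# stand-in for the learned (max_positions + offset, features) embedding table
table = jnp.arange((max_positions + offset) * features, dtype=jnp.float32)
table = table.reshape(max_positions + offset, features)

positions = jnp.array([0, 1, 2])
embeddings = table[positions + offset]  # rows 2, 3, 4 of the table
print(embeddings.shape)  # (3, 4)
```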
+
+class FlaxOPTDecoder(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ offset: int = 2
+
+ def setup(self):
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+
+ embed_dim = self.config.hidden_size
+ self.padding_idx = self.config.pad_token_id
+ self.max_target_positions = self.config.max_position_embeddings
+
+ self.embed_tokens = nn.Embed(
+ self.config.vocab_size,
+ self.config.word_embed_proj_dim,
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+
+ self.embed_positions = FlaxOPTLearnedPositionalEmbedding(
+ self.config.max_position_embeddings,
+ embed_dim,
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+
+ if self.config.word_embed_proj_dim != self.config.hidden_size:
+ self.project_in = nn.Dense(self.config.hidden_size, use_bias=False)
+ self.project_out = nn.Dense(self.config.word_embed_proj_dim, use_bias=False)
+
+ else:
+ self.project_in = None
+ self.project_out = None
+
+ # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+ # with checkpoints that have been fine-tuned before transformers v4.20.1
+ # see https://github.com/facebookresearch/metaseq/pull/164
+ if self.config.do_layer_norm_before and not self.config._remove_final_layer_norm:
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+ else:
+ self.final_layer_norm = None
+
+ self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+ input_shape = input_ids.shape
+ input_ids = input_ids.reshape(-1, input_shape[-1])
+
+ inputs_embeds = self.embed_tokens(input_ids)
+ if self.project_in is not None:
+ inputs_embeds = self.project_in(inputs_embeds)
+
+ positions = self.embed_positions(position_ids)
+
+ hidden_states = inputs_embeds + positions
+
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+
+ hidden_state, all_hidden_states, attentions = self.layers(
+ hidden_states,
+ attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ if self.final_layer_norm is not None:
+ hidden_state = self.final_layer_norm(hidden_state)
+
+ if self.project_out is not None:
+ hidden_state = self.project_out(hidden_state)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_state,)
+
+ outputs = [hidden_state, all_hidden_states, attentions]
+
+ if not return_dict:
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutput(
+ last_hidden_state=hidden_state,
+ hidden_states=all_hidden_states,
+ attentions=attentions,
+ )
+
+
+class FlaxOPTPreTrainedModel(FlaxPreTrainedModel):
+ config_class = OPTConfig
+ base_model_prefix: str = "model"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: OPTConfig,
+ input_shape: Tuple[int] = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs
+ ):
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+ attention_mask = jnp.ones_like(input_ids)
+
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ module_init_outputs = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ position_ids,
+ return_dict=False,
+ )
+
+ random_params = module_init_outputs["params"]
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ def init_cache(self, batch_size, max_length):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ """
+ # init input variables to retrieve cache
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+ attention_mask = jnp.ones_like(input_ids, dtype="i4")
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
+ )
+ return unfreeze(init_variables["cache"])
+
+ def __call__(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ params: dict = None,
+ past_key_values: dict = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ dropout_rng: PRNGKey = None,
+ deterministic: bool = True,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+
+ if position_ids is None:
+ position_ids = (attention_mask.cumsum(axis=1) * attention_mask) - 1
+
+ # Handle any PRNG if needed
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+ inputs = {"params": params or self.params}
+
+ # If past_key_values are passed, the cache is already initialized and the private flag init_cache has to
+ # be passed down to ensure the cache is used. It also has to be made sure that the cache is marked as
+ # mutable so that it can be changed by the FlaxOPTAttention module.
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ outputs = self.module.apply(
+ inputs,
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ rngs=rngs,
+ mutable=mutable,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past_key_values = outputs
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past_key_values = outputs
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
+
+ return outputs
+
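
For context, here is a hedged usage sketch of the fast-decoding cache wired up above (not part of the patch). It assumes the public facebook/opt-125m checkpoint and uses `FlaxOPTForCausalLM.prepare_inputs_for_generation`, defined further down in this file, which mirrors what `generate` does internally for the first decoding step:

```python
from transformers import FlaxOPTForCausalLM, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-125m")
model = FlaxOPTForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer("Hello, my dog is", return_tensors="np")
# pre-allocates the key/value cache (via init_cache) and builds a static
# attention mask plus position ids for the prompt
prepared = model.prepare_inputs_for_generation(
    inputs["input_ids"], max_length=16, attention_mask=inputs["attention_mask"]
)
outputs = model(inputs["input_ids"], **prepared)
# outputs.past_key_values now holds the filled cache for the next decoding step
```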
+
+class FlaxOPTModule(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.decoder = FlaxOPTDecoder(self.config, dtype=self.dtype)
+
+ def _get_decoder_module(self):
+ return self.decoder
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ init_cache=False,
+ ):
+
+ decoder_outputs = self.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ )
+
+ if not return_dict:
+ return decoder_outputs
+
+ return FlaxBaseModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ hidden_states=decoder_outputs.hidden_states,
+ attentions=decoder_outputs.attentions,
+ )
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModel with Bart->OPT
+class FlaxOPTModel(FlaxOPTPreTrainedModel):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ module_class = FlaxOPTModule
+
+
+append_call_sample_docstring(
+ FlaxOPTModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC
+)
+
+
+@add_start_docstrings(
+ "The bare OPT Model transformer outputting raw hidden-states without any specific head on top.",
+ OPT_START_DOCSTRING,
+)
+class FlaxOPTForCausalLMModule(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.model = FlaxOPTModule(config=self.config, dtype=self.dtype)
+ self.lm_head = nn.Dense(
+ self.config.vocab_size,
+ use_bias=False,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+
+ outputs = self.model(
+ input_ids,
+ attention_mask,
+ position_ids,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ hidden_states = outputs[0]
+
+ if self.config.tie_word_embeddings:
+ shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+ else:
+ lm_logits = self.lm_head(hidden_states)
+
+ if not return_dict:
+ return (lm_logits,) + outputs[1:]
+
+ return FlaxMaskedLMOutput(
+ logits=lm_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ OPT Model with a language modeling head on top (linear layer with weights tied to the input embeddings), e.g. for
+ autoregressive tasks.
+ """,
+ OPT_START_DOCSTRING,
+)
+class FlaxOPTForCausalLM(FlaxOPTPreTrainedModel):
+ module_class = FlaxOPTForCausalLMModule
+
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+ # initializing the cache
+ batch_size, seq_length = input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length)
+ # Note that usually one would have to put 0's in the attention_mask for positions x > input_ids.shape[-1] and x < max_length.
+ # But since the decoder uses a causal mask, those positions are masked anyway.
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+
+ if attention_mask is not None:
+ position_ids = attention_mask.cumsum(axis=1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
+ else:
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+ return {
+ "past_key_values": past_key_values,
+ "attention_mask": extended_attention_mask,
+ "position_ids": position_ids,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
+ return model_kwargs
+
+
+append_call_sample_docstring(
+ FlaxOPTForCausalLM,
+ _TOKENIZER_FOR_DOC,
+ _CHECKPOINT_FOR_DOC,
+ FlaxBaseModelOutput,
+ _CONFIG_FOR_DOC,
+)
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
new file mode 100644
index 000000000000..419c2391e4c7
--- /dev/null
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -0,0 +1,1115 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch OPT model."""
+import random
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_opt import OPTConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
+_CONFIG_FOR_DOC = "OPTConfig"
+_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
+
+# Base model docstring
+_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
+
+# SequenceClassification docstring
+_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc"
+_SEQ_CLASS_EXPECTED_LOSS = 1.71
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
+
+
+OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "facebook/opt-125m",
+ "facebook/opt-350m",
+ "facebook/opt-1.3b",
+ "facebook/opt-2.7b",
+ "facebook/opt-6.7b",
+ "facebook/opt-13b",
+ "facebook/opt-30b",
+ # See all OPT models at https://huggingface.co/models?filter=opt
+]
+
+
+def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+ """
+ Make causal mask used for uni-directional (causal) self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
+ mask_cond = torch.arange(mask.size(-1))
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
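
For intuition, a small sketch of what the two helpers above produce for a short padded batch (not part of the patch; it assumes `_make_causal_mask` and `_expand_mask` from this file are in scope):

```python
import torch

# causal part: (1, 1, 4, 4), zeros on/below the diagonal, dtype-min above it
causal = _make_causal_mask(torch.Size([1, 4]), torch.float32)

# padding part: mark the last of the four positions as padding
pad = torch.tensor([[1, 1, 1, 0]], dtype=torch.float32)
expanded = _expand_mask(pad, torch.float32)  # (1, 1, 4, 4), dtype-min in the padded column

# _prepare_decoder_attention_mask simply adds the two together
combined = causal + expanded
print(combined.shape)  # torch.Size([1, 1, 4, 4])
```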
+
+class OPTLearnedPositionalEmbedding(nn.Embedding):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int):
+ # OPT is set up so that, if padding_idx is specified, the embedding ids are offset by 2 and
+ # num_embeddings is adjusted accordingly. Other models don't have this hack.
+ self.offset = 2
+ super().__init__(num_embeddings + self.offset, embedding_dim)
+
+ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+ attention_mask = attention_mask.long()
+
+ # create positions depending on attention_mask
+ positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
+
+ # cut positions if `past_key_values_length` is > 0
+ positions = positions[:, past_key_values_length:]
+
+ return super().forward(positions + self.offset)
+
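
A small sketch of the position computation above for a left-padded sequence (plain tensors; not part of the patch):

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # two left-padding tokens
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
# tensor([[-1, -1,  0,  1,  2]]): real tokens count from 0, padding collapses to -1;
# the module then adds `self.offset` (2) before indexing the embedding table.
print(positions)
```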
+
+class OPTAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+ dtype_attn_weights = attn_weights.dtype
+
+ # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
+ if dtype_attn_weights == torch.float16:
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights)
+ else:
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
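
The dtype branch above upcasts half-precision attention scores before the softmax. A minimal sketch of the same trick in isolation (not part of the patch):

```python
import torch
from torch import nn

scores = torch.randn(2, 4, 4, dtype=torch.float16)
# compute the softmax in float32 for numerical stability, then cast back to fp16
probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(scores.dtype)
print(probs.dtype)  # torch.float16
```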
+
+class OPTDecoderLayer(nn.Module):
+ def __init__(self, config: OPTConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = OPTAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ )
+ self.do_layer_norm_before = config.do_layer_norm_before
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
+ self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ # 125m, 1.3B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ hidden_states_shape = hidden_states.shape
+ hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
+ residual = hidden_states
+
+ # 125m, 1.3B, ..., 175B applies layer norm BEFORE the fully connected block
+ if self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ hidden_states = (residual + hidden_states).view(hidden_states_shape)
+
+ # 350m applies layer norm AFTER the fully connected block
+ if not self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+OPT_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`OPTConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare OPT Model outputting raw hidden-states without any specific head on top.",
+ OPT_START_DOCSTRING,
+)
+class OPTPreTrainedModel(PreTrainedModel):
+
+ config_class = OPTConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["OPTDecoderLayer"]
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (OPTDecoder)):
+ module.gradient_checkpointing = value
+
+
+OPT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
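
To make the `past_key_values` contract above concrete, here is a hedged sketch of a manual two-step decode (not part of the patch; it uses `OPTForCausalLM`, defined further down in this file, and assumes the public facebook/opt-125m checkpoint):

```python
import torch
from transformers import GPT2Tokenizer, OPTForCausalLM

tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-125m")
model = OPTForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, use_cache=True)  # prefill: cache all prompt keys/values

next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones_like(next_token)], dim=-1)
with torch.no_grad():
    # only the new token is fed; the cached keys/values stand in for the prompt
    out = model(
        input_ids=next_token,
        attention_mask=attention_mask,
        past_key_values=out.past_key_values,
        use_cache=True,
    )
```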
+
+class OPTDecoder(OPTPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is an [`OPTDecoderLayer`].
+
+ Args:
+ config: OPTConfig
+ """
+
+ def __init__(self, config: OPTConfig):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layerdrop = config.layerdrop
+ self.padding_idx = config.pad_token_id
+ self.max_target_positions = config.max_position_embeddings
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
+ self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
+
+ if config.word_embed_proj_dim != config.hidden_size:
+ self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
+ else:
+ self.project_out = None
+
+ if config.word_embed_proj_dim != config.hidden_size:
+ self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
+ else:
+ self.project_in = None
+
+ # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+ # with checkpoints that have been fine-tuned before transformers v4.20.1
+ # see https://github.com/facebookresearch/metaseq/pull/164
+ if config.do_layer_norm_before and not config._remove_final_layer_norm:
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size)
+ else:
+ self.final_layer_norm = None
+
+ self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
+ ).to(inputs_embeds.device)
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
+ combined_attention_mask = (
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
+ pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
+
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ if self.project_in is not None:
+ inputs_embeds = self.project_in(inputs_embeds)
+
+ hidden_states = inputs_embeds + pos_embeds
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ dropout_probability = random.uniform(0, 1)
+ if self.training and (dropout_probability < self.layerdrop):
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ None,
+ )
+ else:
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if self.final_layer_norm is not None:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ if self.project_out is not None:
+ hidden_states = self.project_out(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+@add_start_docstrings(
+ "The bare OPT Model outputting raw hidden-states without any specific head on top.",
+ OPT_START_DOCSTRING,
+)
+class OPTModel(OPTPreTrainedModel):
+ def __init__(self, config: OPTConfig):
+ super().__init__(config)
+ self.decoder = OPTDecoder(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.decoder.embed_tokens = value
+
+ def get_decoder(self):
+ return self.decoder
+
+ @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPast,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return decoder_outputs
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ hidden_states=decoder_outputs.hidden_states,
+ attentions=decoder_outputs.attentions,
+ )
+
+
+class OPTForCausalLM(OPTPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = OPTModel(config)
+
+ # the lm_head weight is automatically tied to the embed tokens weight
+ self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.decoder = decoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import GPT2Tokenizer, OPTForCausalLM
+
+ >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
+
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = self.lm_head(outputs[0]).contiguous()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+
+ if past:
+ input_ids = input_ids[:, -1:]
+ # first step, decoder_cached_states are empty
+ return {
+ "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "use_cache": use_cache,
+ }
+
+ @staticmethod
+ def _reorder_cache(past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
+
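
Because `forward` shifts the labels internally (token *t* predicts token *t + 1*), passing `labels=input_ids` is the usual way to get a language-modeling loss. A hedged sketch (not part of the patch; checkpoint name assumed):

```python
from transformers import GPT2Tokenizer, OPTForCausalLM

tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-125m")
model = OPTForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer("OPT is a family of open pretrained transformers.", return_tensors="pt")
# labels == input_ids: the model shifts them internally before the cross-entropy
outputs = model(**inputs, labels=inputs["input_ids"])
print(float(outputs.loss))  # next-token prediction loss
```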
+
+@add_start_docstrings(
+ """
+ The OPT Model transformer with a sequence classification head on top (linear layer).
+
+ [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ OPT_START_DOCSTRING,
+)
+class OPTForSequenceClassification(OPTPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+ def __init__(self, config: OPTConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = OPTModel(config)
+ self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
+ output_type=SequenceClassifierOutputWithPast,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.model(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size, sequence_length = input_ids.shape[:2]
+ else:
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
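+ # index of the last non-padding token per row, e.g. input_ids = [[5, 6, pad, pad]] -> sequence_lengths = [1]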
+ else:
+ sequence_lengths = -1
+ logger.warning(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
diff --git a/src/transformers/models/opt/modeling_tf_opt.py b/src/transformers/models/opt/modeling_tf_opt.py
new file mode 100644
index 000000000000..633e972069ee
--- /dev/null
+++ b/src/transformers/models/opt/modeling_tf_opt.py
@@ -0,0 +1,1030 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TF 2.0 OPT model."""
+
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast
+
+# Public API
+from ...modeling_tf_utils import (
+ DUMMY_INPUTS,
+ TFCausalLanguageModelingLoss,
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSharedEmbeddings,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list, stable_softmax
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_opt import OPTConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
+_CONFIG_FOR_DOC = "OPTConfig"
+_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
+
+# Base model docstring
+_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
+
+# Causal LM output
+_CAUSAL_LM_EXPECTED_OUTPUT = "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+
+LARGE_NEGATIVE = -1e8
+
+
+# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
+def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
+ """
+ Make causal mask used for uni-directional (causal) self-attention.
+ """
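+ # e.g. for tgt_len = 3 and no past, the additive mask (before batching) is
+ # [[0, LARGE_NEGATIVE, LARGE_NEGATIVE],
+ #  [0, 0,              LARGE_NEGATIVE],
+ #  [0, 0,              0             ]]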
+ bsz = input_ids_shape[0]
+ tgt_len = input_ids_shape[1]
+ mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
+ mask_cond = tf.range(shape_list(mask)[-1])
+
+ mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
+
+ if past_key_values_length > 0:
+ mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
+
+ return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
+
+
+# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ src_len = shape_list(mask)[1]
+ tgt_len = tgt_len if tgt_len is not None else src_len
+ one_cst = tf.constant(1.0)
+ mask = tf.cast(mask, dtype=one_cst.dtype)
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
+
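+ # padding positions (mask == 0) become LARGE_NEGATIVE, attended positions (mask == 1) become 0,
+ # e.g. mask = [[1, 1, 0]] -> each row of the expanded additive mask is [0, 0, LARGE_NEGATIVE]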
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
+
+
+class TFOPTLearnedPositionalEmbedding(TFSharedEmbeddings):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
+ # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
+ # and adjust num_embeddings appropriately. Other models don't have this hack
+ self.offset = 2
+ super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs)
+
+ def call(self, attention_mask, past_key_values_length: int = 0):
+ """`attention_mask` is expected to be of shape [bsz x seqlen]."""
+ attention_mask = tf.cast(attention_mask, tf.int64)
+
+ # create positions depending on attention_mask
+ positions = tf.math.cumsum(attention_mask, axis=1) * attention_mask - 1
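+ # e.g. attention_mask = [[1, 1, 1, 0, 0]] -> cumsum = [[1, 2, 3, 3, 3]] -> * mask = [[1, 2, 3, 0, 0]]
+ # -> - 1 = [[0, 1, 2, -1, -1]], so all padding positions map to -1 before `self.offset` is added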
+
+ # cut positions if `past_key_values_length` is > 0
+ positions = positions[:, past_key_values_length:]
+
+ return super().call(positions + self.offset)
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->OPT
+class TFOPTAttention(tf.keras.layers.Layer):
+ """Multi-headed attention from 'Attention Is All You Need'"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = tf.keras.layers.Dropout(dropout)
+ self.head_dim = embed_dim // num_heads
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
+ self.q_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
+ self.v_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
+ self.out_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
+
+ def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
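+ # split the hidden dimension into heads: (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)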
+ return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ key_value_states: Optional[tf.Tensor] = None,
+ past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
+ attention_mask: Optional[tf.Tensor] = None,
+ layer_head_mask: Optional[tf.Tensor] = None,
+ training: Optional[bool] = False,
+ ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ bsz, tgt_len, embed_dim = shape_list(hidden_states)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = tf.concat([past_key_value[0], key_states], axis=2)
+ value_states = tf.concat([past_key_value[1], value_states], axis=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
+ key_states = tf.reshape(key_states, proj_shape)
+ value_states = tf.reshape(value_states, proj_shape)
+
+ src_len = shape_list(key_states)[1]
+ attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
+
+ # The tf.debugging asserts are not compliant with XLA, so they must be
+ # disabled in all modes other than eager.
+ if tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(attn_weights),
+ [bsz * self.num_heads, tgt_len, src_len],
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
+ )
+
+ if attention_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA, so they must be
+ # disabled in all modes other than eager.
+ if tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(attention_mask),
+ [bsz, 1, tgt_len, src_len],
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
+ )
+
+ attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
+ attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+ attn_weights = stable_softmax(attn_weights, axis=-1)
+
+ if layer_head_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA, so they must be
+ # disabled in all modes other than eager.
+ if tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(layer_head_mask),
+ [self.num_heads],
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
+ )
+
+ attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
+ attn_weights, (bsz, self.num_heads, tgt_len, src_len)
+ )
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+ attn_probs = self.dropout(attn_weights, training=training)
+ attn_output = tf.matmul(attn_probs, value_states)
+
+ # The tf.debugging asserts are not compliant with XLA, so they must be
+ # disabled in all modes other than eager.
+ if tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(attn_output),
+ [bsz * self.num_heads, tgt_len, self.head_dim],
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
+ )
+
+ attn_output = tf.transpose(
+ tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
+ )
+ attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
+
+ attn_output = self.out_proj(attn_output)
+ attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
+
+ return attn_output, attn_weights, past_key_value
+
+
+class TFOPTDecoderLayer(tf.keras.layers.Layer):
+ def __init__(self, config: OPTConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.do_layer_norm_before = config.do_layer_norm_before
+ self.embed_dim = config.hidden_size
+ self.self_attn = TFOPTAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ name="self_attn",
+ is_decoder=True,
+ )
+ self.dropout = tf.keras.layers.Dropout(config.dropout)
+ self.activation_fn = get_tf_activation(config.activation_function)
+
+ self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
+ self.fc1 = tf.keras.layers.Dense(config.ffn_dim, name="fc1")
+ self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
+ self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ layer_head_mask: Optional[tf.Tensor] = None,
+ past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ training: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
+ """
+ Args:
+ hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`tf.Tensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`tf.Tensor`, *optional*): mask for attention heads in a given layer of size
+ `(decoder_attention_heads,)`
+ past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+ """
+ residual = hidden_states
+
+ # 125m, 1.7B, ..., 175B apply layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=self_attn_past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ )
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ residual = hidden_states
+ # 125m, 1.7B, ..., 175B apply layer norm BEFORE the feed-forward block
+ if self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ # 350m applies layer norm AFTER the feed-forward block
+ if not self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ return (hidden_states, self_attn_weights, present_key_value)
+
+
+OPT_START_DOCSTRING = r"""
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+ behavior.
+
+
+
+ TF 2.0 models accept two formats as inputs:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ This second option is useful when using the [`tf.keras.Model.fit`] method, which currently requires having all the
+ tensors in the first argument of the model call function: `model(inputs)`.
+
+ If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the
+ first positional argument (see the sketch after this list):
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
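+
+ A minimal sketch showing both input formats (the checkpoint name is an illustrative assumption):
+
+ ```python
+ >>> from transformers import GPT2Tokenizer, TFOPTModel
+
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
+ >>> model = TFOPTModel.from_pretrained("facebook/opt-350m")
+
+ >>> encoded = tokenizer("Hello, my dog is cute", return_tensors="tf")
+ >>> outputs = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])  # keyword arguments
+ >>> outputs = model({"input_ids": encoded["input_ids"], "attention_mask": encoded["attention_mask"]})  # dict input
+ ```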
+
+
+
+ Args:
+ config ([`OPTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare OPT Model outputting raw hidden-states without any specific head on top.",
+ OPT_START_DOCSTRING,
+)
+class TFOPTPreTrainedModel(TFPreTrainedModel):
+ """
+ TFOPT Pretrained Model that inherits from transformers.TFPreTrainedModel
+
+ Args:
+ config: OPTConfig
+ """
+
+ config_class = OPTConfig
+ base_model_prefix = "model"
+
+ @property
+ def dummy_inputs(self):
+ pad_token = 1
+ input_ids = tf.cast(tf.convert_to_tensor(DUMMY_INPUTS), tf.int32)
+ dummy_inputs = {
+ "attention_mask": tf.math.not_equal(input_ids, pad_token),
+ "input_ids": input_ids,
+ }
+ return dummy_inputs
+
+ @tf.function(
+ input_signature=[
+ {
+ "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
+ "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
+ }
+ ]
+ )
+ def serving(self, inputs):
+ output = self.call(inputs)
+
+ return self.serving_output(output)
+
+
+OPT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`tf.Tensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`). Set to `False` during training and `True` during generation.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@keras_serializable
+class TFOPTDecoder(tf.keras.layers.Layer):
+ config_class = OPTConfig
+
+ def __init__(self, config: OPTConfig, load_weight_prefix=None, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.layerdrop = config.layerdrop
+ num_embeddings = config.max_position_embeddings
+ self.embed_tokens = TFSharedEmbeddings(
+ config.vocab_size, config.word_embed_proj_dim, config.pad_token_id, name="embed_tokens"
+ )
+ self.embed_positions = TFOPTLearnedPositionalEmbedding(
+ num_embeddings,
+ config.hidden_size,
+ name="embed_positions",
+ )
+
+ # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+ # with checkpoints that have been fine-tuned before transformers v4.20.1
+ # see https://github.com/facebookresearch/metaseq/pull/164
+ if config.do_layer_norm_before and not config._remove_final_layer_norm:
+ self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
+ else:
+ self.final_layer_norm = None
+
+ if config.word_embed_proj_dim != config.hidden_size:
+ self.project_out = tf.keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False)
+ self.project_in = tf.keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False)
+
+ else:
+ self.project_in = None
+ self.project_out = None
+
+ self.layers = [TFOPTDecoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)]
+ self.dropout = tf.keras.layers.Dropout(config.dropout)
+
+ def get_embed_tokens(self):
+ return self.embed_tokens
+
+ def set_embed_tokens(self, embed_tokens):
+ self.embed_tokens = embed_tokens
+
+ def set_input_embeddings(self, new_embeddings):
+ self.embed_tokens.vocab_size = new_embeddings.shape[0]
+ self.embed_tokens.weight = new_embeddings
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
+ # create causal mask
+ # # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+ else:
+ combined_attention_mask = _expand_mask(
+ tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
+ )
+
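+ # both masks are additive (0 keeps a position, LARGE_NEGATIVE drops it), so summing them masks a
+ # position whenever either the causal mask or the padding mask does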
+ if attention_mask is not None:
+ combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
+
+ return combined_attention_mask
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: Optional[TFModelInputType] = None,
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]:
+ r"""
+ Args:
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
+ decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if attention_mask is None:
+ attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
+
+ pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
+
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length)
+
+ if self.project_in is not None:
+ inputs_embeds = self.project_in(inputs_embeds)
+
+ hidden_states = inputs_embeds + pos_embeds
+ hidden_states = self.dropout(hidden_states, training=training)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ present_key_values = () if use_cache else None
+
+ # check that head_mask has a correct number of layers specified if desired
+ # The tf.debugging asserts are not compliant with XLA, so they must be
+ # disabled in all modes other than eager.
+ for attn_mask_name, attn_mask in [("head_mask", head_mask)]:
+ if attn_mask is not None and tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(attn_mask)[0],
+ len(self.layers),
+ message=(
+ f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(attn_mask)[0]}."
+ ),
+ )
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ hidden_states, layer_self_attn, present_key_value = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=head_mask[idx] if head_mask is not None else None,
+ past_key_value=past_key_value,
+ )
+
+ if use_cache:
+ present_key_values += (present_key_value,)
+
+ if output_attentions:
+ all_self_attns += (layer_self_attn,)
+
+ if self.final_layer_norm is not None:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ if self.project_out is not None:
+ hidden_states = self.project_out(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns] if v is not None
+ )
+
+ else:
+ return TFBaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=present_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+@keras_serializable
+class TFOPTMainLayer(tf.keras.layers.Layer):
+ config_class = OPTConfig
+
+ def __init__(self, config: OPTConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.decoder = TFOPTDecoder(config, name="decoder")
+
+ def get_input_embeddings(self):
+ return self.decoder.embed_tokens
+
+ def set_input_embeddings(self, new_embeddings):
+ self.decoder.set_input_embeddings(new_embeddings)
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: Optional[TFModelInputType] = None,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ **kwargs
+ ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]:
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.decoder(
+ input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ if not return_dict:
+ return outputs
+
+ return TFBaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ "The bare TF OPT Model outputting raw hidden-states without any specific head on top.",
+ OPT_START_DOCSTRING,
+)
+@keras_serializable
+class TFOPTModel(TFOPTPreTrainedModel):
+ config_class = OPTConfig
+
+ def __init__(self, config: OPTConfig, **kwargs):
+ super().__init__(config, **kwargs)
+ self.config = config
+ self.model = TFOPTMainLayer(config, name="model")
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, new_embeddings):
+ self.model.set_input_embeddings(new_embeddings)
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPast,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def call(
+ self,
+ input_ids: Optional[TFModelInputType] = None,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ **kwargs
+ ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]:
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ if not return_dict:
+ return outputs
+
+ return TFBaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def serving_output(self, output):
+ pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+ hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+
+ return TFBaseModelOutputWithPast(
+ last_hidden_state=output.last_hidden_state,
+ past_key_values=pkv,
+ hidden_states=hs,
+ attentions=attns,
+ )
+
+
+@add_start_docstrings(
+ """
+ The OPT Model transformer with a language modeling head on top.
+ """,
+ OPT_START_DOCSTRING,
+)
+@keras_serializable
+class TFOPTForCausalLM(TFOPTPreTrainedModel, TFCausalLanguageModelingLoss):
+ config_class = OPTConfig
+
+ def __init__(self, config: OPTConfig, **kwargs):
+ super().__init__(config, **kwargs)
+ self.config = config
+ self.model = TFOPTMainLayer(config, name="model")
+
+ def get_output_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
+ attention_mask = kwargs.get("attention_mask", None)
+
+ # only last token for inputs_ids if past is defined in kwargs
+ if past:
+ inputs = tf.expand_dims(inputs[:, -1], -1)
+
+ return {
+ "input_ids": inputs,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "use_cache": use_cache,
+ }
+
+ @unpack_inputs
+ @replace_return_docstrings(output_type=TFCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFCausalLMOutputWithPast,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CAUSAL_LM_EXPECTED_OUTPUT,
+ )
+ def call(
+ self,
+ input_ids: Optional[TFModelInputType] = None,
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ **kwargs
+ ) -> Union[TFCausalLMOutputWithPast, Tuple[tf.Tensor]]:
+ r"""
+ Args:
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`tf.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`Tuple[Tuple[tf.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `Tuple[tf.Tensor]` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ logits = self.model.decoder.embed_tokens(outputs[0], mode="linear")
+ loss = None
+ if labels is not None:
+ # shift labels to the left and cut last logit token
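+ # e.g. tokens [t0, t1, t2, t3]: logits at positions [t0, t1, t2] are scored against labels [t1, t2, t3]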
+ shifted_logits = logits[:, :-1]
+ labels = labels[:, 1:]
+ loss = self.hf_compute_loss(labels, shifted_logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def serving_output(self, output):
+ pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+ hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+
+ return TFCausalLMOutputWithPast(
+ past_key_values=pkv,
+ hidden_states=hs,
+ attentions=attns,
+ loss=output.loss,
+ logits=output.logits,
+ )
diff --git a/src/transformers/models/owlvit/__init__.py b/src/transformers/models/owlvit/__init__.py
new file mode 100644
index 000000000000..8315df69faac
--- /dev/null
+++ b/src/transformers/models/owlvit/__init__.py
@@ -0,0 +1,100 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+ is_vision_available,
+)
+
+
+_import_structure = {
+ "configuration_owlvit": [
+ "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "OwlViTConfig",
+ "OwlViTTextConfig",
+ "OwlViTVisionConfig",
+ ],
+ "processing_owlvit": ["OwlViTProcessor"],
+}
+
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_owlvit"] = ["OwlViTFeatureExtractor"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_owlvit"] = [
+ "OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "OwlViTModel",
+ "OwlViTPreTrainedModel",
+ "OwlViTTextModel",
+ "OwlViTVisionModel",
+ "OwlViTForObjectDetection",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_owlvit import (
+ OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ OwlViTConfig,
+ OwlViTTextConfig,
+ OwlViTVisionConfig,
+ )
+ from .processing_owlvit import OwlViTProcessor
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_owlvit import OwlViTFeatureExtractor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_owlvit import (
+ OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ OwlViTForObjectDetection,
+ OwlViTModel,
+ OwlViTPreTrainedModel,
+ OwlViTTextModel,
+ OwlViTVisionModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/owlvit/configuration_owlvit.py b/src/transformers/models/owlvit/configuration_owlvit.py
new file mode 100644
index 000000000000..85ffdbadbeff
--- /dev/null
+++ b/src/transformers/models/owlvit/configuration_owlvit.py
@@ -0,0 +1,336 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" OWL-ViT model configuration"""
+
+import copy
+import os
+from typing import Dict, Union
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "google/owlvit-base-patch32": "https://huggingface.co/google/owlvit-base-patch32/resolve/main/config.json",
+ "google/owlvit-base-patch16": "https://huggingface.co/google/owlvit-base-patch16/resolve/main/config.json",
+ "google/owlvit-large-patch14": "https://huggingface.co/google/owlvit-large-patch14/resolve/main/config.json",
+}
+
+
+class OwlViTTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of an [`OwlViTTextModel`]. It is used to instantiate an
+ OwlViT text encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the OwlViT
+ [google/owlvit-base-patch32](https://huggingface.co/google/owlvit-base-patch32) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 49408):
+ Vocabulary size of the OWL-ViT text model. Defines the number of different tokens that can be represented
+ by the `inputs_ids` passed when calling [`OwlViTTextModel`].
+ hidden_size (`int`, *optional*, defaults to 512):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 2048):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ max_position_embeddings (`int`, *optional*, defaults to 16):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+
+ Example:
+
+ ```python
+ >>> from transformers import OwlViTTextConfig, OwlViTTextModel
+
+ >>> # Initializing an OwlViTTextConfig with google/owlvit-base-patch32 style configuration
+ >>> configuration = OwlViTTextConfig()
+
+ >>> # Initializing an OwlViTTextModel from the google/owlvit-base-patch32 style configuration
+ >>> model = OwlViTTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "owlvit_text_model"
+
+ def __init__(
+ self,
+ vocab_size=49408,
+ hidden_size=512,
+ intermediate_size=2048,
+ num_hidden_layers=12,
+ num_attention_heads=8,
+ max_position_embeddings=16,
+ hidden_act="quick_gelu",
+ layer_norm_eps=0.00001,
+ dropout=0.0,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ pad_token_id=0,
+ bos_token_id=49406,
+ eos_token_id=49407,
+ **kwargs
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the text config dict if we are loading from OwlViTConfig
+ if config_dict.get("model_type") == "owlvit":
+ config_dict = config_dict["text_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class OwlViTVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of an [`OwlViTVisionModel`]. It is used to instantiate
+ an OWL-ViT image encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the OWL-ViT
+ [google/owlvit-base-patch32](https://huggingface.co/google/owlvit-base-patch32) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ image_size (`int`, *optional*, defaults to 768):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 32):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+
+ Example:
+
+ ```python
+ >>> from transformers import OwlViTVisionConfig, OwlViTVisionModel
+
+ >>> # Initializing an OwlViTVisionConfig with google/owlvit-base-patch32 style configuration
+ >>> configuration = OwlViTVisionConfig()
+
+ >>> # Initializing an OwlViTVisionModel from the google/owlvit-base-patch32 style configuration
+ >>> model = OwlViTVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "owlvit_vision_model"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ intermediate_size=3072,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ image_size=768,
+ patch_size=32,
+ hidden_act="quick_gelu",
+ layer_norm_eps=0.00001,
+ dropout=0.0,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the vision config dict if we are loading from OwlViTConfig
+ if config_dict.get("model_type") == "owlvit":
+ config_dict = config_dict["vision_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class OwlViTConfig(PretrainedConfig):
+ r"""
+ [`OwlViTConfig`] is the configuration class to store the configuration of an [`OwlViTModel`]. It is used to
+ instantiate an OWL-ViT model according to the specified arguments, defining the text model and vision model
+ configs.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`OwlViTTextConfig`].
+ vision_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`OwlViTVisionConfig`].
+ projection_dim (`int`, *optional*, defaults to 512):
+ Dimensionality of text and vision projection layers.
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
+ The initial value of the *logit_scale* parameter. Default is used as per the original OWL-ViT
+ implementation.
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
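+
+ Example (an illustrative sketch using the default text and vision sub-configurations):
+
+ ```python
+ >>> from transformers import OwlViTConfig, OwlViTModel
+
+ >>> # Initializing an OwlViTConfig with the default sub-configs
+ >>> configuration = OwlViTConfig()
+
+ >>> # Initializing an OwlViTModel (with random weights) from that configuration
+ >>> model = OwlViTModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```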
+ """
+
+ model_type = "owlvit"
+ is_composition = True
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ projection_dim=512,
+ logit_scale_init_value=2.6592,
+ return_dict=True,
+ **kwargs
+ ):
+ super().__init__(text_config=text_config, vision_config=vision_config, **kwargs)
+
+ if text_config is None:
+ text_config = {}
+ logger.info("text_config_dict is None. Initializing the OwlViTTextConfig with default values.")
+
+ if vision_config is None:
+ vision_config = {}
+ logger.info("vision_config_dict is None. initializing the OwlViTVisionConfig with default values.")
+
+ self.text_config = OwlViTTextConfig(**text_config)
+ self.vision_config = OwlViTVisionConfig(**vision_config)
+
+ self.projection_dim = projection_dim
+ self.logit_scale_init_value = logit_scale_init_value
+ self.return_dict = return_dict
+ self.initializer_factor = 1.0
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+ @classmethod
+ def from_text_vision_configs(cls, text_config: Dict, vision_config: Dict, **kwargs):
+ r"""
+ Instantiate an [`OwlViTConfig`] (or a derived class) from an OWL-ViT text model configuration and an OWL-ViT
+ vision model configuration.
+
+ Returns:
+ [`OwlViTConfig`]: An instance of a configuration object
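+
+ Example (an illustrative sketch; here the sub-config dictionaries are simply the library defaults):
+
+ ```python
+ >>> from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
+
+ >>> text_config = OwlViTTextConfig().to_dict()
+ >>> vision_config = OwlViTVisionConfig().to_dict()
+ >>> config = OwlViTConfig.from_text_vision_configs(text_config, vision_config)
+ ```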
+ """
+ config_dict = {}
+ config_dict["text_config"] = text_config
+ config_dict["vision_config"] = vision_config
+
+ return cls.from_dict(config_dict, **kwargs)
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+ Returns:
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["text_config"] = self.text_config.to_dict()
+ output["vision_config"] = self.vision_config.to_dict()
+ output["model_type"] = self.__class__.model_type
+ return output
diff --git a/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py b/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py
new file mode 100644
index 000000000000..dde57c168ade
--- /dev/null
+++ b/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py
@@ -0,0 +1,407 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert OWL-ViT checkpoints from the original repository. URL:
+https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit"""
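+
+# Example invocation (illustrative; the checkpoint path and hub repo name below are hypothetical placeholders):
+#
+# python convert_owlvit_original_flax_to_hf.py \
+#     --owlvit_version clip_b32 \
+#     --owlvit_checkpoint /path/to/flax/checkpoint \
+#     --hf_config google/owlvit-base-patch32 \
+#     --pytorch_dump_folder_path owlvit-base-patch32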
+
+import argparse
+import collections
+
+import torch
+import torch.nn as nn
+
+import jax
+import jax.numpy as jnp
+from clip.model import CLIP
+from flax.training import checkpoints
+from huggingface_hub import Repository
+from transformers import (
+ CLIPTokenizer,
+ OwlViTConfig,
+ OwlViTFeatureExtractor,
+ OwlViTForObjectDetection,
+ OwlViTModel,
+ OwlViTProcessor,
+)
+
+
+CONFIGS = {
+ "vit_b32": dict(
+ embed_dim=512,
+ image_resolution=768,
+ context_length=16,
+ vocab_size=49408,
+ vision_layers=12,
+ vision_width=768,
+ vision_patch_size=32,
+ transformer_width=512,
+ transformer_heads=8,
+ transformer_layers=12,
+ ),
+ "vit_b16": dict(
+ embed_dim=512,
+ image_resolution=768,
+ context_length=16,
+ vocab_size=49408,
+ vision_layers=12,
+ vision_width=768,
+ vision_patch_size=16,
+ transformer_width=512,
+ transformer_heads=8,
+ transformer_layers=12,
+ ),
+ "vit_l14": dict(
+ embed_dim=768,
+ image_resolution=840,
+ context_length=16,
+ vocab_size=49408,
+ vision_layers=24,
+ vision_width=1024,
+ vision_patch_size=14,
+ transformer_width=768,
+ transformer_heads=12,
+ transformer_layers=12,
+ ),
+}
+
+
+def flatten_nested_dict(params, parent_key="", sep="/"):
+ items = []
+
+ for k, v in params.items():
+ new_key = parent_key + sep + k if parent_key else k
+
+ if isinstance(v, collections.abc.MutableMapping):
+ items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
+ else:
+ items.append((new_key, v))
+ return dict(items)
+
+
+def to_f32(params):
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, params)
+
+
+def copy_attn_layer(hf_attn_layer, pt_attn_layer):
+ q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
+ q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)
+
+ out_proj_weights = pt_attn_layer.out_proj.weight
+ out_proj_bias = pt_attn_layer.out_proj.bias
+
+ hf_attn_layer.q_proj.weight.data = q_proj
+ hf_attn_layer.q_proj.bias.data = q_proj_bias
+
+ hf_attn_layer.k_proj.weight.data = k_proj
+ hf_attn_layer.k_proj.bias.data = k_proj_bias
+
+ hf_attn_layer.v_proj.weight.data = v_proj
+ hf_attn_layer.v_proj.bias.data = v_proj_bias
+
+ hf_attn_layer.out_proj.weight = out_proj_weights
+ hf_attn_layer.out_proj.bias = out_proj_bias
+
+
+def copy_mlp(hf_mlp, pt_mlp):
+ copy_linear(hf_mlp.fc1, pt_mlp.c_fc)
+ copy_linear(hf_mlp.fc2, pt_mlp.c_proj)
+
+
+def copy_linear(hf_linear, pt_linear):
+ hf_linear.weight = pt_linear.weight
+ hf_linear.bias = pt_linear.bias
+
+
+def copy_layer(hf_layer, pt_layer):
+ # copy layer norms
+ copy_linear(hf_layer.layer_norm1, pt_layer.ln_1)
+ copy_linear(hf_layer.layer_norm2, pt_layer.ln_2)
+
+ # copy MLP
+ copy_mlp(hf_layer.mlp, pt_layer.mlp)
+
+ # copy attn
+ copy_attn_layer(hf_layer.self_attn, pt_layer.attn)
+
+
+def copy_layers(hf_layers, pt_layers):
+ for hf_layer, pt_layer in zip(hf_layers, pt_layers):
+ copy_layer(hf_layer, pt_layer)
+
+
+def copy_encoder(hf_encoder, pt_model):
+ # copy embeds
+ hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight
+ hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding
+
+ # copy layer norm
+ copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)
+
+ # copy hidden layers
+ copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)
+
+
+def copy_text_model_and_projection(hf_model, pt_model):
+ # copy projection
+ hf_model.text_projection.weight.data = pt_model.text_projection.data.T
+
+ # copy text encoder
+ copy_encoder(hf_model.text_model, pt_model)
+
+
+def copy_vision_model_and_projection(hf_model, pt_model):
+ # copy projection
+ hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T
+
+ # copy layer norms
+ copy_linear(hf_model.vision_model.pre_layernorm, pt_model.visual.ln_pre)
+ copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post)
+
+ # copy embeds
+ hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data
+ hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding
+ hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data
+
+ # copy encoder
+ copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks)
+
+
+def copy_class_merge_token(hf_model, flax_params):
+ flax_class_token_params = flatten_nested_dict(flax_params["backbone"]["merged_class_token"])
+
+ weight = torch.from_numpy(flax_class_token_params["scale"])
+ bias = torch.from_numpy(flax_class_token_params["bias"])
+ hf_model.layer_norm.weight = nn.Parameter(weight)
+ hf_model.layer_norm.bias = nn.Parameter(bias)
+
+
+def copy_class_box_heads(hf_model, flax_params):
+ pt_params = hf_model.state_dict()
+ new_params = {}
+
+ # Rename class prediction head flax params to pytorch HF
+ flax_class_params = flatten_nested_dict(flax_params["class_head"])
+
+ for flax_key, v in flax_class_params.items():
+ torch_key = flax_key.replace("/", ".")
+ torch_key = torch_key.replace(".kernel", ".weight")
+ torch_key = torch_key.replace("Dense_0", "dense0")
+ torch_key = "class_head." + torch_key
+
+ if "weight" in torch_key and v.ndim == 2:
+ v = v.T
+
+ new_params[torch_key] = nn.Parameter(torch.from_numpy(v))
+
+ # Rename box prediction box flax params to pytorch HF
+ flax_box_params = flatten_nested_dict(flax_params["obj_box_head"])
+
+ for flax_key, v in flax_box_params.items():
+ torch_key = flax_key.replace("/", ".")
+ torch_key = torch_key.replace(".kernel", ".weight")
+ torch_key = torch_key.replace("_", "").lower()
+ torch_key = "box_head." + torch_key
+
+ if "weight" in torch_key and v.ndim == 2:
+ v = v.T
+
+ new_params[torch_key] = nn.Parameter(torch.from_numpy(v))
+
+ # Copy flax params to PyTorch params
+ for name, param in new_params.items():
+ if name in pt_params.keys():
+ pt_params[name].copy_(param)
+
+
+def copy_flax_attn_params(hf_backbone, flax_attn_params):
+ for k, v in flax_attn_params.items():
+ if k.startswith("transformer"):
+ torch_key = k.replace("transformer.resblocks", "text_model.encoder.layers")
+ else:
+ torch_key = k.replace("visual.transformer.resblocks", "vision_model.encoder.layers")
+
+ torch_key = torch_key.replace("attn", "self_attn")
+ torch_key = torch_key.replace("key", "k_proj")
+ torch_key = torch_key.replace("value", "v_proj")
+ torch_key = torch_key.replace("query", "q_proj")
+ torch_key = torch_key.replace("out", "out_proj")
+
+ if "bias" in torch_key and v.ndim == 2:
+ shape = v.shape[0] * v.shape[1]
+ v = v.reshape(shape)
+
+ if "weight" in torch_key and "out" in torch_key:
+ shape = (v.shape[0] * v.shape[1], v.shape[2])
+ v = v.reshape(shape).T
+
+ if "weight" in torch_key and "out" not in torch_key:
+ shape = (v.shape[0], v.shape[1] * v.shape[2])
+ v = v.reshape(shape).T
+
+ # Copy flax CLIP attn params to HF PyTorch params
+ v = torch.from_numpy(v)
+ hf_backbone.state_dict()[torch_key].copy_(v)
+
+
+def _convert_attn_layers(params):
+ new_params = {}
+ processed_attn_layers = []
+
+ for k, v in params.items():
+ if "attn." in k:
+ base = k[: k.rindex("attn.") + 5]
+ if base in processed_attn_layers:
+ continue
+
+ processed_attn_layers.append(base)
+ dim = params[base + "out.weight"].shape[-1]
+ new_params[base + "out_proj.weight"] = params[base + "out.weight"].reshape(dim, dim).T
+ new_params[base + "out_proj.bias"] = params[base + "out.bias"]
+ else:
+ new_params[k] = v
+ return new_params
+
+
+def convert_clip_backbone(flax_params, torch_config):
+ torch_model = CLIP(**torch_config)
+ torch_model.eval()
+ torch_clip_params = torch_model.state_dict()
+
+ flax_clip_params = flatten_nested_dict(flax_params["backbone"]["clip"])
+ new_torch_params = {}
+
+ for flax_key, v in flax_clip_params.items():
+ torch_key = flax_key.replace("/", ".")
+ torch_key = torch_key.replace("text.token_embedding.embedding", "token_embedding.kernel")
+
+ if (
+ torch_key.startswith("text.transformer")
+ or torch_key.startswith("text.text_projection")
+ or torch_key.startswith("text.ln_final")
+ or torch_key.startswith("text.positional_embedding")
+ ):
+ torch_key = torch_key[5:]
+
+ torch_key = torch_key.replace("text_projection.kernel", "text_projection")
+ torch_key = torch_key.replace("visual.proj.kernel", "visual.proj")
+ torch_key = torch_key.replace(".scale", ".weight")
+ torch_key = torch_key.replace(".kernel", ".weight")
+
+ if "conv" in torch_key or "downsample.0.weight" in torch_key:
+ v = v.transpose(3, 2, 0, 1)
+
+ elif "weight" in torch_key and v.ndim == 2 and "embedding" not in torch_key:
+ # Fully connected layers are transposed, embeddings are not
+ v = v.T
+
+ new_torch_params[torch_key] = v
+
+ attn_params = _convert_attn_layers(new_torch_params)
+ new_torch_params.update(attn_params)
+ attn_params = {}
+
+ # Copy flax CLIP backbone params to PyTorch params
+ for name, param in new_torch_params.items():
+ if name in torch_clip_params.keys():
+
+ new_param = torch.from_numpy(new_torch_params[name])
+ torch_clip_params[name].copy_(new_param)
+ else:
+ attn_params[name] = param
+
+ return torch_clip_params, torch_model, attn_params
+
+
+@torch.no_grad()
+def convert_owlvit_checkpoint(pt_backbone, flax_params, attn_params, pytorch_dump_folder_path, config_path=None):
+ """
+ Copy/paste/tweak model's weights to transformers design.
+ """
+ repo = Repository(pytorch_dump_folder_path, clone_from=f"google/{pytorch_dump_folder_path}")
+ repo.git_pull()
+
+ if config_path is not None:
+ config = OwlViTConfig.from_pretrained(config_path)
+ else:
+ config = OwlViTConfig()
+
+ hf_backbone = OwlViTModel(config).eval()
+ hf_model = OwlViTForObjectDetection(config).eval()
+
+ copy_text_model_and_projection(hf_backbone, pt_backbone)
+ copy_vision_model_and_projection(hf_backbone, pt_backbone)
+ hf_backbone.logit_scale = pt_backbone.logit_scale
+ copy_flax_attn_params(hf_backbone, attn_params)
+
+ hf_model.owlvit = hf_backbone
+ copy_class_merge_token(hf_model, flax_params)
+ copy_class_box_heads(hf_model, flax_params)
+
+ # Save HF model
+ hf_model.save_pretrained(repo.local_dir)
+
+ # Initialize feature extractor
+ feature_extractor = OwlViTFeatureExtractor(
+ size=config.vision_config.image_size, crop_size=config.vision_config.image_size
+ )
+ # Initialize tokenizer
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32", pad_token="!", model_max_length=16)
+
+ # Initialize processor
+ processor = OwlViTProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+ feature_extractor.save_pretrained(repo.local_dir)
+ processor.save_pretrained(repo.local_dir)
+
+ repo.git_add()
+ repo.git_commit("Upload model and processor")
+ repo.git_push()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--owlvit_version",
+ default=None,
+ type=str,
+ required=True,
+ help="OWL-ViT model name [clip_b16, clip_b32, clip_l14].",
+ )
+ parser.add_argument(
+ "--owlvit_checkpoint", default=None, type=str, required=True, help="Path to flax model checkpoint."
+ )
+ parser.add_argument("--hf_config", default=None, type=str, required=True, help="Path to HF model config.")
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default="hf_model", type=str, help="Path to the output PyTorch model."
+ )
+ args = parser.parse_args()
+
+ # Initialize PyTorch CLIP model
+ model_name = args.owlvit_version
+ if model_name == "clip_b16":
+ torch_config = CONFIGS["vit_b16"]
+ elif model_name == "clip_b32":
+ torch_config = CONFIGS["vit_b32"]
+ elif model_name == "clip_l14":
+ torch_config = CONFIGS["vit_l14"]
+
+ # Load from checkpoint and convert params to float-32
+ variables = checkpoints.restore_checkpoint(args.owlvit_checkpoint, target=None)["optimizer"]["target"]
+ flax_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, variables)
+ del variables
+
+ # Convert CLIP backbone
+ pt_backbone_params, clip_pt, attn_params = convert_clip_backbone(flax_params, torch_config)
+
+ convert_owlvit_checkpoint(clip_pt, flax_params, attn_params, args.pytorch_dump_folder_path, args.hf_config)
diff --git a/src/transformers/models/owlvit/feature_extraction_owlvit.py b/src/transformers/models/owlvit/feature_extraction_owlvit.py
new file mode 100644
index 000000000000..1e4bc735608a
--- /dev/null
+++ b/src/transformers/models/owlvit/feature_extraction_owlvit.py
@@ -0,0 +1,210 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for OwlViT."""
+
+from typing import List, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
+from ...utils import TensorType, is_torch_available, logging
+
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__)
+
+
+def center_to_corners_format(x):
+ """
+ Converts a PyTorch tensor of bounding boxes from center format (center_x, center_y, width, height) to corners
+ format (left, top, right, bottom).
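+ For example, a box `(0.5, 0.5, 1.0, 1.0)` in center format corresponds to `(0.0, 0.0, 1.0, 1.0)` in corners
+ format.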
+ """
+ x_center, y_center, width, height = x.unbind(-1)
+ boxes = [(x_center - 0.5 * width), (y_center - 0.5 * height), (x_center + 0.5 * width), (y_center + 0.5 * height)]
+ return torch.stack(boxes, dim=-1)
+
+
+class OwlViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
+ r"""
+ Constructs an OWL-ViT feature extractor.
+
+ This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
+ should refer to this superclass for more information regarding those methods.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the shorter edge of the input to a certain `size`.
+ size (`int`, *optional*, defaults to 768):
+ Resize the shorter edge of the input to the given size. Only has an effect if `do_resize` is set to `True`.
+ resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`):
+ An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+ `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+ if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
+ image is padded with 0's and then center cropped.
+ crop_size (`int`, *optional*, defaults to 768):
+ Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input with `image_mean` and `image_std`.
+ image_mean (`List[int]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+ The sequence of means for each channel, to be used when normalizing images.
+ image_std (`List[int]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+ The sequence of standard deviations for each channel, to be used when normalizing images.
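+
+ Example (an illustrative sketch; the image path is a hypothetical placeholder):
+
+ ```python
+ >>> from PIL import Image
+ >>> from transformers import OwlViTFeatureExtractor
+
+ >>> feature_extractor = OwlViTFeatureExtractor()
+ >>> image = Image.open("path/to/image.jpg")
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
+ >>> # inputs["pixel_values"] has shape (1, 3, 768, 768) under the default size and crop_size
+ ```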
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize=True,
+ size=768,
+ resample=Image.BICUBIC,
+ crop_size=768,
+ do_center_crop=True,
+ do_normalize=True,
+ image_mean=None,
+ image_std=None,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.size = size
+ self.resample = resample
+ self.crop_size = crop_size
+ self.do_resize = do_resize
+ self.do_center_crop = do_center_crop
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
+ self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
+
+ def post_process(self, outputs, target_sizes):
+ """
+ Converts the output of [`OwlViTForObjectDetection`] into the format expected by the COCO api.
+
+ Args:
+ outputs ([`OwlViTObjectDetectionOutput`]):
+ Raw outputs of the model.
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
+ Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
+ image size (before any data augmentation). For visualization, this should be the image size after data
+ augmentation, but before padding.
+ Returns:
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model.
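+
+ Example (an illustrative sketch; assumes the `google/owlvit-base-patch32` checkpoint and the accompanying
+ `OwlViTProcessor` and `OwlViTForObjectDetection` classes):
+
+ ```python
+ >>> import torch
+ >>> import requests
+ >>> from PIL import Image
+ >>> from transformers import OwlViTProcessor, OwlViTForObjectDetection
+
+ >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
+ >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> # (height, width) of the original image, used to rescale the predicted boxes
+ >>> target_sizes = torch.tensor([image.size[::-1]])
+ >>> results = processor.feature_extractor.post_process(outputs=outputs, target_sizes=target_sizes)
+ ```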
+ """
+ logits, boxes = outputs.logits, outputs.pred_boxes
+
+ if len(logits) != len(target_sizes):
+ raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
+ if target_sizes.shape[1] != 2:
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
+
+ probs = torch.max(logits, dim=-1)
+ scores = torch.sigmoid(probs.values)
+ labels = probs.indices
+
+ # Convert to [x0, y0, x1, y1] format
+ boxes = center_to_corners_format(boxes)
+
+ # Convert from relative [0, 1] to absolute [0, height] coordinates
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+ boxes = boxes * scale_fct[:, None, :]
+
+ results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
+
+ return results
+
+ def __call__(
+ self,
+ images: Union[
+ Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
+ ],
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several image(s).
+
+ NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so it is most efficient to pass
+ PIL images.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W) or (H, W, C),
+ where C is a number of channels, H and W are image height and width.
+
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model.
+ """
+ # Input type checking for clearer error
+ valid_images = False
+
+ # Check that images has a valid type
+ if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
+ valid_images = True
+ elif isinstance(images, (list, tuple)):
+ if isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
+ valid_images = True
+
+ if not valid_images:
+ raise ValueError(
+ "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
+ "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
+ )
+
+ is_batched = bool(
+ isinstance(images, (list, tuple))
+ and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
+ )
+
+ if not is_batched:
+ images = [images]
+
+ # transformations (resizing + center cropping + normalization)
+ if self.do_resize and self.size is not None and self.resample is not None:
+ images = [
+ self.resize(image=image, size=self.size, resample=self.resample, default_to_square=False)
+ for image in images
+ ]
+ if self.do_center_crop and self.crop_size is not None:
+ images = [self.center_crop(image, self.crop_size) for image in images]
+ if self.do_normalize:
+ images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
+
+ # return as BatchFeature
+ data = {"pixel_values": images}
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ return encoded_inputs
diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py
new file mode 100644
index 000000000000..35ebd16cf25b
--- /dev/null
+++ b/src/transformers/models/owlvit/modeling_owlvit.py
@@ -0,0 +1,1396 @@
+# coding=utf-8
+# Copyright 2022 Google AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch OWL-ViT model."""
+
+
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_owlvit import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/owlvit-base-patch32"
+
+# See all OwlViT models at https://huggingface.co/models?filter=owlvit
+OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "google/owlvit-base-patch32",
+ "google/owlvit-base-patch16",
+ "google/owlvit-large-patch14",
+]
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+# Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlvit
+def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
+
+
+# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlvit
+def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor:
+ caption_loss = contrastive_loss(similarity)
+ image_loss = contrastive_loss(similarity.T)
+ return (caption_loss + image_loss) / 2.0
+
+
+@dataclass
+class OwlViTOutput(ModelOutput):
+ """
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+ Contrastive loss for image-text similarity.
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
+ similarity scores.
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
+ similarity scores.
+ text_embeds (`torch.FloatTensor` of shape `(batch_size * num_max_text_queries, output_dim)`):
+ The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+ The image embeddings obtained by applying the projection layer to the pooled output of
+ [`OwlViTVisionModel`].
+ text_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`OwlViTTextModel`].
+ vision_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`OwlViTVisionModel`].
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits_per_image: torch.FloatTensor = None
+ logits_per_text: torch.FloatTensor = None
+ text_embeds: torch.FloatTensor = None
+ image_embeds: torch.FloatTensor = None
+ text_model_output: BaseModelOutputWithPooling = None
+ vision_model_output: BaseModelOutputWithPooling = None
+
+ def to_tuple(self) -> Tuple[Any]:
+ return tuple(
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
+ for k in self.keys()
+ )
+
+
+@dataclass
+class OwlViTObjectDetectionOutput(ModelOutput):
+ """
+ Output type of [`OwlViTForObjectDetection`].
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+ Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+ scale-invariant IoU loss.
+ loss_dict (`Dict`, *optional*):
+ A dictionary containing the individual losses. Useful for logging.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
+ Classification logits (including no-object) for all queries.
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
+ Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+ possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the unnormalized
+ bounding boxes.
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim)`):
+ The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
+ Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
+ image embeddings for each patch.
+ class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
+ Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total
+ number of patches is (image_size / patch_size)**2.
+ text_model_last_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Last hidden states extracted from the [`OwlViTTextModel`].
+ vision_model_last_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_patches + 1, hidden_size)`):
+ Last hidden states extracted from the [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image
+ patches where the total number of patches is (image_size / patch_size)**2.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ loss_dict: Optional[Dict] = None
+ logits: torch.FloatTensor = None
+ pred_boxes: torch.FloatTensor = None
+ text_embeds: torch.FloatTensor = None
+ image_embeds: torch.FloatTensor = None
+ class_embeds: torch.FloatTensor = None
+ text_model_last_hidden_states: Optional[torch.FloatTensor] = None
+ vision_model_last_hidden_states: Optional[torch.FloatTensor] = None
+
+
+class OwlViTVisionEmbeddings(nn.Module):
+ def __init__(self, config: OwlViTVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.class_embedding = nn.Parameter(torch.randn(config.hidden_size))
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=config.patch_size,
+ stride=config.patch_size,
+ bias=False,
+ )
+
+ self.num_patches = (config.image_size // config.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ batch_size = pixel_values.shape[0]
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [batch_size, embed_dim, grid_height, grid_width]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+
+ return embeddings
+
+
+class OwlViTTextEmbeddings(nn.Module):
+ def __init__(self, config: OwlViTTextConfig):
+ super().__init__()
+ self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+
+ return embeddings
+
+
+class OwlViTAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scale
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->OwlViT
+class OwlViTMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->OwlViT
+class OwlViTEncoderLayer(nn.Module):
+ def __init__(self, config: OwlViTConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = OwlViTAttention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim)
+ self.mlp = OwlViTMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ causal_attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ `(config.encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class OwlViTPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = OwlViTConfig
+ base_model_prefix = "owlvit"
+ supports_gradient_checkpointing = True
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ factor = self.config.initializer_factor
+ if isinstance(module, OwlViTTextEmbeddings):
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
+ elif isinstance(module, OwlViTVisionEmbeddings):
+ factor = self.config.initializer_factor
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+ elif isinstance(module, OwlViTAttention):
+ factor = self.config.initializer_factor
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+ out_proj_std = (module.embed_dim**-0.5) * factor
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
+ elif isinstance(module, OwlViTMLP):
+ factor = self.config.initializer_factor
+ in_proj_std = (
+ (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+ )
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
+ nn.init.normal_(module.fc1.weight, std=fc_std)
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
+ elif isinstance(module, OwlViTModel):
+ nn.init.normal_(
+ module.text_projection.weight,
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
+ )
+ nn.init.normal_(
+ module.visual_projection.weight,
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
+ )
+ if isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, OwlViTEncoder):
+ module.gradient_checkpointing = value
+
+
+OWLVIT_START_DOCSTRING = r"""
+ Parameters:
+ This model is a PyTorch [torch.nn.Module](https:
+ //pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+ config ([`OwlViTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+OWLVIT_TEXT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`CLIPTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, num_max_text_queries, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+OWLVIT_VISION_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+OWLVIT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`CLIPTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values.
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+OWLVIT_OBJECT_DETECTION_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values.
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`CLIPTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, num_max_text_queries, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+"""
+
+
+class OwlViTEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`OwlViTEncoderLayer`].
+
+ Args:
+ config: OwlViTConfig
+ """
+
+ def __init__(self, config: OwlViTConfig):
+ super().__init__()
+ self.layers = nn.ModuleList([OwlViTEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Embedded representation of the inputs to the encoder.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class OwlViTTextTransformer(nn.Module):
+ def __init__(self, config: OwlViTTextConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+ self.embeddings = OwlViTTextEmbeddings(config)
+ self.encoder = OwlViTEncoder(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
+
+ @add_start_docstrings_to_model_forward(OWLVIT_TEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTTextConfig)
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+ num_samples, seq_len = input_shape # num_samples = batch_size * num_max_text_queries
+ # OWLVIT's text model uses causal mask, prepare it here.
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+ causal_attention_mask = self._build_causal_attention_mask(num_samples, seq_len).to(hidden_states.device)
+ # expand attention_mask
+ if attention_mask is not None:
+ # [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+ # take features from the end-of-token (EOT) embedding (the EOT token has the highest token id in each sequence)
+ pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+ def _build_causal_attention_mask(self, bsz, seq_len):
+ # lazily create causal attention mask so that each token can only attend to itself and earlier tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(bsz, seq_len, seq_len)
+ mask.fill_(torch.tensor(float("-inf")))
+ mask.triu_(1) # zero out the diagonal and the lower triangle, keep -inf strictly above the diagonal
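+ # e.g. for seq_len = 3 the resulting (additive) mask per sample is:
+ # [[0., -inf, -inf],
+ # [0., 0., -inf],
+ # [0., 0., 0.]]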
+ mask = mask.unsqueeze(1) # expand mask
+ return mask
+
+
+class OwlViTTextModel(OwlViTPreTrainedModel):
+ config_class = OwlViTTextConfig
+
+ def __init__(self, config: OwlViTTextConfig):
+ super().__init__(config)
+ self.text_model = OwlViTTextTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, value):
+ self.text_model.embeddings.token_embedding = value
+
+ @add_start_docstrings_to_model_forward(OWLVIT_TEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTTextConfig)
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ Examples:
+ ```python
+ >>> from transformers import OwlViTProcessor, OwlViTTextModel
+
+ >>> model = OwlViTTextModel.from_pretrained("google/owlvit-base-patch32")
+ >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
+ >>> inputs = processor(
+ ... text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt"
+ ... )
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ # Get embeddings for all text queries in all batch samples
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+class OwlViTVisionTransformer(nn.Module):
+ def __init__(self, config: OwlViTVisionConfig):
+ super().__init__()
+ self.config = config
+
+ self.embeddings = OwlViTVisionEmbeddings(config)
+ self.pre_layernorm = nn.LayerNorm(config.hidden_size)
+ self.encoder = OwlViTEncoder(config)
+ self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+ @add_start_docstrings_to_model_forward(OWLVIT_VISION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTVisionConfig)
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ use_hidden_state: Optional[bool] = True,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states = self.embeddings(pixel_values)
+ hidden_states = self.pre_layernorm(hidden_states)
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ pooled_output = last_hidden_state[:, 0, :]
+
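+ # Note: when `use_hidden_state` is True (the default), `pooled_output` below holds the layer-normalized
+ # hidden states of *all* tokens rather than a single pooled vector; `OwlViTForObjectDetection` relies on
+ # this (via `get_image_features(..., return_projected=False)`) to keep per-patch features.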
+ if use_hidden_state:
+ pooled_output = self.post_layernorm(last_hidden_state)
+ else:
+ pooled_output = self.post_layernorm(pooled_output)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class OwlViTVisionModel(OwlViTPreTrainedModel):
+ config_class = OwlViTVisionConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: OwlViTVisionConfig):
+ super().__init__(config)
+ self.vision_model = OwlViTVisionTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ @add_start_docstrings_to_model_forward(OWLVIT_VISION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTVisionConfig)
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ Examples:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import OwlViTProcessor, OwlViTVisionModel
+
+ >>> model = OwlViTVisionModel.from_pretrained("google/owlvit-base-patch32")
+ >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
+ ```"""
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+@add_start_docstrings(OWLVIT_START_DOCSTRING)
+class OwlViTModel(OwlViTPreTrainedModel):
+ config_class = OwlViTConfig
+
+ def __init__(self, config: OwlViTConfig):
+ super().__init__(config)
+
+ if not isinstance(config.text_config, OwlViTTextConfig):
+ raise ValueError(
+ "config.text_config is expected to be of type OwlViTTextConfig but is of type"
+ f" {type(config.text_config)}."
+ )
+
+ if not isinstance(config.vision_config, OwlViTVisionConfig):
+ raise ValueError(
+ "config.vision_config is expected to be of type OwlViTVisionConfig but is of type"
+ f" {type(config.vision_config)}."
+ )
+
+ text_config = config.text_config
+ vision_config = config.vision_config
+
+ self.projection_dim = config.projection_dim
+ self.text_embed_dim = text_config.hidden_size
+ self.vision_embed_dim = vision_config.hidden_size
+
+ self.text_model = OwlViTTextTransformer(text_config)
+ self.vision_model = OwlViTVisionTransformer(vision_config)
+
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
+ self.logit_scale = nn.Parameter(torch.ones([]) * config.logit_scale_init_value)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(OWLVIT_TEXT_INPUTS_DOCSTRING)
+ def get_text_features(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+ applying the projection layer to the pooled output of [`OwlViTTextModel`].
+
+ Examples:
+ ```python
+ >>> from transformers import OwlViTProcessor, OwlViTModel
+
+ >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
+ >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
+ >>> inputs = processor(
+ ... text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt"
+ ... )
+ >>> text_features = model.get_text_features(**inputs)
+ ```"""
+ # Use OWL-ViT model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Get embeddings for all text queries in all batch samples
+ text_output = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = text_output[1]
+ text_features = self.text_projection(pooled_output)
+ return text_features
+
+ @add_start_docstrings_to_model_forward(OWLVIT_VISION_INPUTS_DOCSTRING)
+ def get_image_features(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ return_projected: Optional[bool] = True,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+ applying the projection layer to the pooled output of [`OwlViTVisionModel`].
+
+ Examples:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import OwlViTProcessor, OwlViTModel
+
+ >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
+ >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = processor(images=image, return_tensors="pt")
+ >>> image_features = model.get_image_features(**inputs)
+ ```"""
+ # Use OWL-ViT model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = vision_outputs[1] # pooled_output
+
+ # Return projected output
+ if return_projected:
+ image_features = self.visual_projection(pooled_output)
+ else:
+ image_features = pooled_output
+ return image_features
+
+ @add_start_docstrings_to_model_forward(OWLVIT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=OwlViTOutput, config_class=OwlViTConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ return_loss: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, OwlViTOutput]:
+ r"""
+ Returns:
+
+ Examples:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import OwlViTProcessor, OwlViTModel
+
+ >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
+ >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
+ ```"""
+ # Use OWL-ViT model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_hidden_state=False,
+ )
+
+ # Get embeddings for all text queries in all batch samples
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ text_embeds = text_outputs[1]
+ text_embeds = self.text_projection(text_embeds)
+ image_embeds = vision_outputs[1]
+ image_embeds = self.visual_projection(image_embeds)
+
+ # normalized features
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+ logits_per_image = logits_per_text.T
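+ # Illustrative shapes (assuming the base checkpoints, where projection_dim is 512): one image and two
+ # text queries give text_embeds [2, 512] and image_embeds [1, 512], hence logits_per_text [2, 1] and
+ # logits_per_image [1, 2].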
+
+ loss = None
+ if return_loss:
+ loss = owlvit_loss(logits_per_text)
+
+ if not return_dict:
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+ return ((loss,) + output) if loss is not None else output
+
+ return OwlViTOutput(
+ loss=loss,
+ logits_per_image=logits_per_image,
+ logits_per_text=logits_per_text,
+ text_embeds=text_embeds,
+ image_embeds=image_embeds,
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
+ )
+
+
+class OwlViTBoxPredictionHead(nn.Module):
+ def __init__(self, config: OwlViTConfig):
+ super().__init__()
+
+ width = config.vision_config.hidden_size
+ self.dense0 = nn.Linear(width, width)
+ self.dense1 = nn.Linear(width, width)
+ self.gelu = nn.GELU()
+ self.dense2 = nn.Linear(width, 4)
+
+ def forward(self, image_features: torch.Tensor) -> torch.FloatTensor:
+ output = self.dense0(image_features)
+ output = self.gelu(output)
+ output = self.dense1(output)
+ output = self.gelu(output)
+ output = self.dense2(output)
+ return output
+
+
+class OwlViTClassPredictionHead(nn.Module):
+ def __init__(self, config: OwlViTConfig):
+ super().__init__()
+
+ out_dim = config.text_config.hidden_size
+ query_dim = config.vision_config.hidden_size
+
+ self.dense0 = nn.Linear(query_dim, out_dim)
+ self.logit_shift = nn.Linear(query_dim, 1)
+ self.logit_scale = nn.Linear(query_dim, 1)
+ self.elu = nn.ELU()
+
+ def forward(
+ self,
+ image_embeds: torch.FloatTensor,
+ query_embeds: torch.FloatTensor,
+ query_mask: torch.Tensor,
+ ) -> Tuple[torch.FloatTensor]:
+
+ image_class_embeds = self.dense0(image_embeds)
+
+ # Normalize image and text features
+ image_class_embeds /= torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6
+ query_embeds /= torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6
+
+ # Get class predictions
+ pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)
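+ # Shapes: image_class_embeds [batch_size, num_patches, out_dim] and query_embeds
+ # [batch_size, num_queries, out_dim] yield pred_logits [batch_size, num_patches, num_queries].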
+
+ # Apply a learnable shift and scale to logits
+ logit_shift = self.logit_shift(image_embeds)
+ logit_scale = self.logit_scale(image_embeds)
+ logit_scale = self.elu(logit_scale) + 1
+ pred_logits = (pred_logits + logit_shift) * logit_scale
+
+ if query_mask is not None:
+ if query_mask.ndim > 1:
+ query_mask = torch.unsqueeze(query_mask, dim=-2)
+
+ pred_logits = pred_logits.to(torch.float64)
+ pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
+ pred_logits = pred_logits.to(torch.float32)
+
+ return (pred_logits, image_class_embeds)
+
+
+class OwlViTForObjectDetection(OwlViTPreTrainedModel):
+ config_class = OwlViTConfig
+
+ def __init__(self, config: OwlViTConfig):
+ super().__init__(config)
+
+ self.owlvit = OwlViTModel(config)
+ self.class_head = OwlViTClassPredictionHead(config)
+ self.box_head = OwlViTBoxPredictionHead(config)
+
+ self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size)
+ self.sigmoid = nn.Sigmoid()
+
+ def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
+ # Computes normalized xy corner coordinates from feature_map.
+ if not feature_map.ndim == 4:
+ raise ValueError("Expected input shape is [batch_size, num_channels, height, width]")
+
+ device = feature_map.device
+ height, width = feature_map.shape[1:3]
+
+ box_coordinates = np.stack(np.meshgrid(np.arange(1, width + 1), np.arange(1, height + 1)), axis=-1).astype(
+ np.float32
+ )
+ box_coordinates /= np.array([width, height], np.float32)
+
+ # Flatten (h, w, 2) -> (h*w, 2)
+ box_coordinates = box_coordinates.reshape(
+ box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
+ )
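+ # Illustrative only: for a 2x2 feature grid this yields [[0.5, 0.5], [1.0, 0.5], [0.5, 1.0], [1.0, 1.0]],
+ # i.e. the normalized far (bottom-right) corner of each grid cell.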
+ box_coordinates = torch.from_numpy(box_coordinates).to(device)
+
+ return box_coordinates
+
+ def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
+ # The box center is biased to its position on the feature grid
+ box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
+ box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
+
+ # Unnormalize xy
+ box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
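+ # This is (up to the 1e-4 stabilizer) the inverse sigmoid (logit) of the grid coordinates, so applying
+ # sigmoid in `box_predictor` maps a zero prediction back onto the corresponding grid position.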
+
+ # The box size is biased to the patch size
+ box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
+ box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
+
+ # Compute box bias
+ box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)
+ return box_bias
+
+ def box_predictor(
+ self,
+ image_feats: torch.FloatTensor,
+ feature_map: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ Args:
+ image_feats:
+ Features extracted from the image, returned by the `image_text_embedder` method.
+ feature_map:
+ A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method.
+ Returns:
+ pred_boxes:
+ Predicted boxes in (center_x, center_y, width, height) format, normalized to [0, 1].
+ """
+ # Bounding box detection head [batch_size, num_boxes, 4].
+ pred_boxes = self.box_head(image_feats)
+
+ # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
+ pred_boxes += self.compute_box_bias(feature_map)
+ pred_boxes = self.sigmoid(pred_boxes)
+ return pred_boxes
+
+ def class_predictor(
+ self,
+ image_feats: torch.FloatTensor,
+ query_embeds: torch.FloatTensor,
+ query_mask: torch.Tensor,
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ """
+ Args:
+ image_feats:
+ Features extracted from the `image_text_embedder`.
+ query_embeds:
+ Text query embeddings.
+ query_mask:
+ Must be provided together with query_embeds. A mask indicating which query embeddings are valid.
+ """
+ (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)
+
+ return (pred_logits, image_class_embeds)
+
+ def image_text_embedder(
+ self,
+ input_ids: torch.Tensor,
+ pixel_values: torch.FloatTensor,
+ attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+ # Encode text
+ text_embeds = self.owlvit.get_text_features(
+ input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions
+ )
+
+ # Encode image
+ image_embeds = self.owlvit.get_image_features(
+ pixel_values, return_projected=False, output_attentions=output_attentions
+ )
+
+ # Resize class token
+ new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
+ class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+
+ # Merge image embedding with class tokens
+ image_embeds = image_embeds[:, 1:, :] * class_token_out
+ image_embeds = self.layer_norm(image_embeds)
+
+ # Resize to [batch_size, num_patches, num_patches, hidden_size]
+ new_size = (
+ image_embeds.shape[0],
+ int(np.sqrt(image_embeds.shape[1])),
+ int(np.sqrt(image_embeds.shape[1])),
+ image_embeds.shape[-1],
+ )
+ image_embeds = image_embeds.reshape(new_size)
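+ # Illustrative shapes, assuming the default base configuration (768x768 input, 32x32 patches,
+ # hidden_size 768): image_embeds goes [batch_size, 577, 768] -> [batch_size, 576, 768]
+ # -> [batch_size, 24, 24, 768].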
+
+ return (image_embeds, text_embeds)
+
+ @add_start_docstrings_to_model_forward(OWLVIT_OBJECT_DETECTION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=OwlViTObjectDetectionOutput, config_class=OwlViTConfig)
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ pixel_values: torch.FloatTensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> OwlViTObjectDetectionOutput:
+ r"""
+ Returns:
+
+ Examples:
+ ```python
+ >>> import requests
+ >>> from PIL import Image
+ >>> import torch
+ >>> from transformers import OwlViTProcessor, OwlViTForObjectDetection
+
+ >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
+ >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> texts = [["a photo of a cat", "a photo of a dog"]]
+ >>> inputs = processor(text=texts, images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
+ >>> target_sizes = torch.Tensor([image.size[::-1]])
+ >>> # Convert outputs (bounding boxes and class logits) to COCO API
+ >>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
+
+ >>> i = 0 # Retrieve predictions for the first image for the corresponding text queries
+ >>> text = texts[i]
+ >>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
+
+ >>> score_threshold = 0.1
+ >>> for box, score, label in zip(boxes, scores, labels):
+ ... box = [round(i, 2) for i in box.tolist()]
+ ... if score >= score_threshold:
+ ... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
+ Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48]
+ Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61]
+ ```"""
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Return last hidden states of text and vision transformers
+ text_model_last_hidden_states = None
+ vision_model_last_hidden_states = None
+
+ if output_hidden_states:
+ outputs = self.owlvit(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ text_model_last_hidden_states = outputs[-2][0]
+ vision_model_last_hidden_states = outputs[-1][0]
+
+ # Embed images and text queries
+ feature_map, query_embeds = self.image_text_embedder(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ batch_size, height, width, hidden_dim = feature_map.shape
+ image_feats = torch.reshape(feature_map, (batch_size, height * width, hidden_dim))
+
+ # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
+ max_text_queries = input_ids.shape[0] // batch_size
+ query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1])
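+ # e.g. two images, each paired with at most three text queries, give input_ids of shape
+ # [6, sequence_length] and query_embeds reshaped to [2, 3, query_dim].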
+
+ # If first token is 0, then this is a padded query [batch_size, num_queries].
+ input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1])
+ query_mask = input_ids[..., 0] > 0
+
+ # Predict object classes [batch_size, num_patches, num_queries+1]
+ (pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask)
+
+ # Predict object boxes
+ pred_boxes = self.box_predictor(image_feats, feature_map)
+
+ if not return_dict:
+ output = (
+ pred_logits,
+ pred_boxes,
+ query_embeds,
+ feature_map,
+ class_embeds,
+ text_model_last_hidden_states,
+ vision_model_last_hidden_states,
+ )
+ output = tuple(x for x in output if x is not None)
+ return output
+
+ return OwlViTObjectDetectionOutput(
+ image_embeds=feature_map,
+ text_embeds=query_embeds,
+ pred_boxes=pred_boxes,
+ logits=pred_logits,
+ class_embeds=class_embeds,
+ text_model_last_hidden_states=text_model_last_hidden_states,
+ vision_model_last_hidden_states=vision_model_last_hidden_states,
+ )
diff --git a/src/transformers/models/owlvit/processing_owlvit.py b/src/transformers/models/owlvit/processing_owlvit.py
new file mode 100644
index 000000000000..48060f0dcf64
--- /dev/null
+++ b/src/transformers/models/owlvit/processing_owlvit.py
@@ -0,0 +1,161 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for OWL-ViT
+"""
+from typing import List
+
+import numpy as np
+
+from transformers import is_flax_available, is_tf_available, is_torch_available
+
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding
+
+
+class OwlViTProcessor(ProcessorMixin):
+ r"""
+ Constructs an OWL-ViT processor which wraps [`OwlViTFeatureExtractor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`]
+ into a single processor that inherits both the feature extractor and tokenizer functionalities. See the
+ [`~OwlViTProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more information.
+
+ Args:
+ feature_extractor ([`OwlViTFeatureExtractor`]):
+ The feature extractor is a required input.
+ tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
+ The tokenizer is a required input.
+ """
+ feature_extractor_class = "OwlViTFeatureExtractor"
+ tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
+
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__(feature_extractor, tokenizer)
+
+ def __call__(self, text=None, images=None, padding="max_length", return_tensors="np", **kwargs):
+ """
+ Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and
+ `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+ CLIPFeatureExtractor's [`~CLIPFeatureExtractor.__call__`] if `images` is not `None`. Please refer to the
+ docstring of the above two methods for more information.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+ `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
+ number of channels, H and W are image height and width.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+ Returns:
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+
+ if text is None and images is None:
+ raise ValueError("You have to specify at least one text or image. Both cannot be none.")
+
+ if text is not None:
+ if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)):
+ encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)]
+
+ elif isinstance(text, List) and isinstance(text[0], List):
+ encodings = []
+
+ # Maximum number of queries across batch
+ max_num_queries = max([len(t) for t in text])
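+ # e.g. text = [["a cat", "a dog"], ["a bird"]] gives max_num_queries = 2, so the second sample
+ # is padded below with a single-space query before tokenization.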
+
+ # Pad all batch samples to max number of text queries
+ for t in text:
+ if len(t) != max_num_queries:
+ t = t + [" "] * (max_num_queries - len(t))
+
+ encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs)
+ encodings.append(encoding)
+ else:
+ raise TypeError("Input text should be a string, a list of strings or a nested list of strings")
+
+ if return_tensors == "np":
+ input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0)
+ attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0)
+
+ elif return_tensors == "jax" and is_flax_available():
+ import jax.numpy as jnp
+
+ input_ids = jnp.concatenate([encoding["input_ids"] for encoding in encodings], axis=0)
+ attention_mask = jnp.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0)
+
+ elif return_tensors == "pt" and is_torch_available():
+ import torch
+
+ input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0)
+ attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0)
+
+ elif return_tensors == "tf" and is_tf_available():
+ import tensorflow as tf
+
+ input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0)
+ attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0)
+
+ else:
+ raise ValueError("Target return tensor type could not be returned")
+
+ encoding = BatchEncoding()
+ encoding["input_ids"] = input_ids
+ encoding["attention_mask"] = attention_mask
+
+ if images is not None:
+ image_features = self.feature_extractor(images, return_tensors=return_tensors, **kwargs)
+
+ if text is not None and images is not None:
+ encoding["pixel_values"] = image_features.pixel_values
+ return encoding
+ elif text is not None:
+ return encoding
+ else:
+ return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
+
+ def post_process(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to [`OwlViTFeatureExtractor.post_process`]. Please refer to the
+ docstring of this method for more information.
+ """
+ return self.feature_extractor.post_process(*args, **kwargs)
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
diff --git a/src/transformers/models/pegasus/__init__.py b/src/transformers/models/pegasus/__init__.py
index 4d01c31c6df2..ca04afeeb1a0 100644
--- a/src/transformers/models/pegasus/__init__.py
+++ b/src/transformers/models/pegasus/__init__.py
@@ -18,6 +18,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
@@ -27,17 +28,30 @@
)
-_import_structure = {
- "configuration_pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig"],
-}
+_import_structure = {"configuration_pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_pegasus"] = ["PegasusTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_pegasus_fast"] = ["PegasusTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_pegasus"] = [
"PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST",
"PegasusForCausalLM",
@@ -46,14 +60,24 @@
"PegasusPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_pegasus"] = [
"TFPegasusForConditionalGeneration",
"TFPegasusModel",
"TFPegasusPreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_pegasus"] = [
"FlaxPegasusForConditionalGeneration",
"FlaxPegasusModel",
@@ -64,13 +88,28 @@
if TYPE_CHECKING:
from .configuration_pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_pegasus import PegasusTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_pegasus_fast import PegasusTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_pegasus import (
PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST,
PegasusForCausalLM,
@@ -79,10 +118,20 @@
PegasusPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_pegasus import (
FlaxPegasusForConditionalGeneration,
FlaxPegasusModel,
diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py
index 81276dcd2adc..303d0055716c 100644
--- a/src/transformers/models/pegasus/modeling_flax_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -544,7 +544,7 @@ def setup(self) -> None:
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
- self.config.encoder_ffn_dim,
+ self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index 2f79fa93fe5a..5a144aa3e9c5 100755
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -102,7 +102,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
@@ -233,7 +233,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -249,7 +250,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -270,7 +272,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -775,7 +778,8 @@ def forward(
if head_mask is not None:
if head_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
@@ -872,11 +876,13 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
@@ -1043,7 +1049,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -1285,10 +1292,10 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
- r"encoder\.version",
- r"decoder\.version",
- r"lm_head\.weight",
- r"embed_positions\.weight",
+ r"encoder.version",
+ r"decoder.version",
+ r"lm_head.weight",
+ r"embed_positions.weight",
]
def __init__(self, config: PegasusConfig):
diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py
index be2539b3a910..85df859c8479 100644
--- a/src/transformers/models/pegasus/modeling_tf_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py
@@ -88,7 +88,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
"""
Make causal mask used for bi-directional self-attention.
"""
- bsz, tgt_len = input_ids_shape
+ bsz = input_ids_shape[0]
+ tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
@@ -101,7 +102,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
-def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -163,12 +164,14 @@ def _init_weight(n_pos: int, dim: int):
tf.stop_gradient(table)
return table
- def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
+ def call(
+ self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: Optional[tf.Tensor] = None
+ ):
"""Input is expected to be of size [bsz x seqlen]."""
- bsz, seq_len = input_shape[:2]
-
- positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
- return tf.gather(self.weight, positions)
+ if position_ids is None:
+ seq_len = input_shape[1]
+ position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
+ return tf.gather(self.weight, position_ids)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Pegasus
@@ -268,7 +271,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
- message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
)
if attention_mask is not None:
@@ -278,7 +284,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
- message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
@@ -294,7 +303,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
@@ -311,7 +323,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
- message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
)
attn_output = tf.transpose(
@@ -615,6 +630,9 @@ def serving(self, inputs):
`past_key_values`).
decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
+ decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.max_position_embeddings - 1]`.
head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
@@ -787,7 +805,10 @@ def call(
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
- message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
+ message=(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(head_mask)[0]}."
+ ),
)
# encoder layers
@@ -861,6 +882,7 @@ def call(
input_ids=None,
inputs_embeds=None,
attention_mask=None,
+ position_ids=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
@@ -889,6 +911,9 @@ def call(
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
+ position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.max_position_embeddings - 1]`.
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
@@ -918,11 +943,11 @@ def call(
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
- all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`tf.Tensor` of
- shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
- `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
- control over how to convert `input_ids` indices into associated vectors than the model's internal
- embedding lookup matrix.
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`tf.Tensor` of shape
+ `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids`
+ you can choose to directly pass an embedded representation. This is useful if you want more control
+ over how to convert `input_ids` indices into associated vectors than the model's internal embedding
+ lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value
@@ -951,7 +976,10 @@ def call(
past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
# embed positions
- positions = self.embed_positions(input_shape, past_key_values_length)
+ if position_ids is None:
+ positions = self.embed_positions(input_shape, past_key_values_length)
+ else:
+ positions = self.embed_positions(input_shape, position_ids=position_ids)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
@@ -989,7 +1017,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
- message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+ message=(
+ f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(attn_mask)[0]}."
+ ),
)
for idx, decoder_layer in enumerate(self.layers):
@@ -1081,6 +1112,7 @@ def call(
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
+ decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@@ -1128,6 +1160,7 @@ def call(
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
+ position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
@@ -1186,6 +1219,7 @@ def call(
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
+ decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@@ -1206,6 +1240,7 @@ def call(
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
+ decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@@ -1257,7 +1292,7 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model = TFPegasusMainLayer(config, name="model")
self.use_cache = config.use_cache
- # final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
+ # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self.final_logits_bias = self.add_weight(
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
)
@@ -1290,6 +1325,7 @@ def call(
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
+ decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@@ -1317,7 +1353,7 @@ def call(
if labels is not None:
labels = tf.where(
labels == self.config.pad_token_id,
- tf.fill(shape_list(labels), -100),
+ tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
labels,
)
use_cache = False
@@ -1332,6 +1368,7 @@ def call(
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
+ decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@@ -1389,6 +1426,7 @@ def prepare_inputs_for_generation(
decoder_input_ids,
past=None,
attention_mask=None,
+ decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@@ -1396,16 +1434,26 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs
):
+
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
+ if decoder_attention_mask is not None: # xla
+ decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
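+ # The exclusive cumulative sum gives, at each position, the number of attended tokens before it,
+ # which serves as the position id of the token currently being generated under XLA.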
+ elif past is not None: # no xla + past
+ decoder_position_ids = past[0][0].shape[2]
+ else: # no xla + no past
+ decoder_position_ids = tf.range(decoder_input_ids.shape[1])
+
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
+ "decoder_attention_mask": decoder_attention_mask,
+ "decoder_position_ids": decoder_position_ids,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
diff --git a/src/transformers/models/pegasus/tokenization_pegasus.py b/src/transformers/models/pegasus/tokenization_pegasus.py
index a6a9167e66de..b4d1cdc19804 100644
--- a/src/transformers/models/pegasus/tokenization_pegasus.py
+++ b/src/transformers/models/pegasus/tokenization_pegasus.py
@@ -119,7 +119,8 @@ def __init__(
if additional_special_tokens is not None:
if not isinstance(additional_special_tokens, list):
raise TypeError(
- f"additional_special_tokens should be of type {type(list)}, but is {type(additional_special_tokens)}"
+ f"additional_special_tokens should be of type {type(list)}, but is"
+ f" {type(additional_special_tokens)}"
)
additional_special_tokens_extended = (
@@ -134,7 +135,8 @@ def __init__(
if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended):
raise ValueError(
- f"Please make sure that the provided additional_special_tokens do not contain an incorrectly shifted list of tokens. Found {additional_special_tokens_extended}."
+ "Please make sure that the provided additional_special_tokens do not contain an incorrectly"
+ f" shifted list of tokens. Found {additional_special_tokens_extended}."
)
additional_special_tokens = additional_special_tokens_extended
else:
diff --git a/src/transformers/models/pegasus/tokenization_pegasus_fast.py b/src/transformers/models/pegasus/tokenization_pegasus_fast.py
index 14399988f0fa..22c6018385f6 100644
--- a/src/transformers/models/pegasus/tokenization_pegasus_fast.py
+++ b/src/transformers/models/pegasus/tokenization_pegasus_fast.py
@@ -115,7 +115,8 @@ def __init__(
if additional_special_tokens is not None:
if not isinstance(additional_special_tokens, list):
raise TypeError(
- f"additional_special_tokens should be of type {type(list)}, but is {type(additional_special_tokens)}"
+ f"additional_special_tokens should be of type {type(list)}, but is"
+ f" {type(additional_special_tokens)}"
)
additional_special_tokens_extended = (
@@ -130,7 +131,8 @@ def __init__(
if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended):
raise ValueError(
- f"Please make sure that the provided additional_special_tokens do not contain an incorrectly shifted list of tokens. Found {additional_special_tokens_extended}."
+ "Please make sure that the provided additional_special_tokens do not contain an incorrectly"
+ f" shifted list of tokens. Found {additional_special_tokens_extended}."
)
additional_special_tokens = additional_special_tokens_extended
else:
@@ -158,7 +160,8 @@ def _special_token_mask(self, seq):
if all_special_ids != set(range(len(self.additional_special_tokens) + 3)):
raise ValueError(
- f"There should be 3 special tokens: mask_token, pad_token, and eos_token + {len(self.additional_special_tokens)} additional_special_tokens, but got {all_special_ids}"
+ "There should be 3 special tokens: mask_token, pad_token, and eos_token +"
+ f" {len(self.additional_special_tokens)} additional_special_tokens, but got {all_special_ids}"
)
return [1 if x in all_special_ids else 0 for x in seq]
diff --git a/src/transformers/models/perceiver/__init__.py b/src/transformers/models/perceiver/__init__.py
index b20818306434..107c62f2eb8a 100644
--- a/src/transformers/models/perceiver/__init__.py
+++ b/src/transformers/models/perceiver/__init__.py
@@ -17,18 +17,34 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available, is_vision_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tokenizers_available,
+ is_torch_available,
+ is_vision_available,
+)
_import_structure = {
- "configuration_perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig"],
+ "configuration_perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverOnnxConfig"],
"tokenization_perceiver": ["PerceiverTokenizer"],
}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_perceiver"] = ["PerceiverFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_perceiver"] = [
"PERCEIVER_PRETRAINED_MODEL_ARCHIVE_LIST",
"PerceiverForImageClassificationConvProcessing",
@@ -45,13 +61,23 @@
if TYPE_CHECKING:
- from .configuration_perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig
+ from .configuration_perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, PerceiverOnnxConfig
from .tokenization_perceiver import PerceiverTokenizer
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_perceiver import PerceiverFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_perceiver import (
PERCEIVER_PRETRAINED_MODEL_ARCHIVE_LIST,
PerceiverForImageClassificationConvProcessing,
diff --git a/src/transformers/models/perceiver/configuration_perceiver.py b/src/transformers/models/perceiver/configuration_perceiver.py
index fdf1f0124350..0c97974441c5 100644
--- a/src/transformers/models/perceiver/configuration_perceiver.py
+++ b/src/transformers/models/perceiver/configuration_perceiver.py
@@ -14,8 +14,15 @@
# limitations under the License.
""" Perceiver model configuration"""
+from collections import OrderedDict
+from typing import Any, Mapping, Optional, Union
+
from ...configuration_utils import PretrainedConfig
-from ...utils import logging
+from ...feature_extraction_utils import FeatureExtractionMixin
+from ...onnx import OnnxConfig
+from ...onnx.utils import compute_effective_axis_dimension
+from ...tokenization_utils_base import PreTrainedTokenizerBase
+from ...utils import TensorType, logging
logger = logging.get_logger(__name__)
@@ -172,3 +179,63 @@ def __init__(
self.audio_samples_per_frame = audio_samples_per_frame
self.samples_per_patch = samples_per_patch
self.output_shape = output_shape
+
+
+class PerceiverOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "multiple-choice":
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+ else:
+ dynamic_axis = {0: "batch", 1: "sequence"}
+ return OrderedDict(
+ [
+ ("inputs", dynamic_axis),
+ ("attention_mask", dynamic_axis),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-4
+
+ def generate_dummy_inputs(
+ self,
+ preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
+ batch_size: int = -1,
+ seq_length: int = -1,
+ num_choices: int = -1,
+ is_pair: bool = False,
+ framework: Optional[TensorType] = None,
+ num_channels: int = 3,
+ image_width: int = 40,
+ image_height: int = 40,
+ ) -> Mapping[str, Any]:
+ # copied from `transformers.onnx.config.OnnxConfig` and slightly altered/simplified
+
+ if isinstance(preprocessor, PreTrainedTokenizerBase):
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+ batch_size = compute_effective_axis_dimension(
+ batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
+ )
+ # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
+ token_to_add = preprocessor.num_special_tokens_to_add(is_pair)
+ seq_length = compute_effective_axis_dimension(
+ seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
+ )
+ # Generate dummy inputs according to compute batch and sequence
+ dummy_input = [" ".join(["a"]) * seq_length] * batch_size
+ inputs = dict(preprocessor(dummy_input, return_tensors=framework))
+ inputs["inputs"] = inputs.pop("input_ids")
+ return inputs
+ elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+ batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
+ dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
+ inputs = dict(preprocessor(images=dummy_input, return_tensors=framework))
+ inputs["inputs"] = inputs.pop("pixel_values")
+ return inputs
+ else:
+ raise ValueError(
+ "Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor."
+ )
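
For reference, a minimal sketch of how the new `PerceiverOnnxConfig.generate_dummy_inputs` is meant to be called from the tokenizer path; this is not part of the patch and assumes the usual `OnnxConfig(config, task=...)` constructor and the `deepmind/language-perceiver` checkpoint:

    from transformers import PerceiverConfig, PerceiverTokenizer
    from transformers.models.perceiver.configuration_perceiver import PerceiverOnnxConfig
    from transformers.utils import TensorType

    config = PerceiverConfig()
    onnx_config = PerceiverOnnxConfig(config, task="default")
    tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")

    # The dummy batch is keyed on "inputs" rather than "input_ids",
    # matching the dynamic axes declared in the `inputs` property above.
    dummy = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
    print(list(dummy.keys()))  # expected to contain "inputs" and "attention_mask"
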
diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py
index 6dc1563b47f0..b3a0beea3d3c 100755
--- a/src/transformers/models/perceiver/modeling_perceiver.py
+++ b/src/transformers/models/perceiver/modeling_perceiver.py
@@ -864,8 +864,8 @@ def forward(
inputs_without_pos = None
if inputs.size()[-1] != self.config.d_model:
raise ValueError(
- f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. "
- "Make sure to set config.d_model appropriately."
+ f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model:"
+ f" {self.config.d_model}. Make sure to set config.d_model appropriately."
)
batch_size, seq_length, _ = inputs.size()
@@ -2181,7 +2181,7 @@ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, su
if self.concat_preprocessed_input:
if inputs_without_pos is None:
raise ValueError("Value is required for inputs_without_pos if concat_preprocessed_input is True")
- pos_emb = torch.cat([inputs_without_pos, pos_emb], div=-1)
+ pos_emb = torch.cat([inputs_without_pos, pos_emb], dim=-1)
return pos_emb
@@ -2735,7 +2735,9 @@ def _check_or_build_spatial_positions(pos, index_dims, batch_size):
"""
if pos is None:
pos = build_linear_positions(index_dims)
- pos = torch.broadcast_to(pos[None], (batch_size,) + pos.shape)
+ # equivalent to `torch.broadcast_to(pos[None], (batch_size,) + pos.shape)`
+ # but `torch.broadcast_to` cannot be converted to ONNX
+ pos = pos[None].expand((batch_size,) + pos.shape)
pos = torch.reshape(pos, [batch_size, np.prod(index_dims), -1])
else:
# Just a warning label: you probably don't want your spatial features to
@@ -2840,7 +2842,8 @@ def __init__(self, config: PerceiverConfig) -> None:
def forward(self, hidden_states: torch.Tensor, embedding_layer: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, d_model = hidden_states.shape
- output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.T) # Flatten batch dim
+ # Flatten batch dim
+ output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.transpose(0, 1))
output = output + self.bias
return output.reshape([batch_size, seq_len, self.vocab_size])
@@ -3166,9 +3169,9 @@ def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, netw
if self.prep_type != "patches":
# move channels to last dimension, as the _build_network_inputs method below expects this
if inputs.ndim == 4:
- inputs = torch.moveaxis(inputs, 1, -1)
+ inputs = torch.permute(inputs, (0, 2, 3, 1))
elif inputs.ndim == 5:
- inputs = torch.moveaxis(inputs, 2, -1)
+ inputs = torch.permute(inputs, (0, 1, 3, 4, 2))
else:
raise ValueError("Unsupported data format for conv1x1.")
diff --git a/src/transformers/models/phobert/__init__.py b/src/transformers/models/phobert/__init__.py
index 0f226f537aa9..0d9a6f4cea1a 100644
--- a/src/transformers/models/phobert/__init__.py
+++ b/src/transformers/models/phobert/__init__.py
@@ -21,9 +21,7 @@
from ...utils import _LazyModule
-_import_structure = {
- "tokenization_phobert": ["PhobertTokenizer"],
-}
+_import_structure = {"tokenization_phobert": ["PhobertTokenizer"]}
if TYPE_CHECKING:
diff --git a/src/transformers/models/plbart/__init__.py b/src/transformers/models/plbart/__init__.py
index 676feeb39cfa..06204a8901e9 100644
--- a/src/transformers/models/plbart/__init__.py
+++ b/src/transformers/models/plbart/__init__.py
@@ -17,17 +17,31 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_sentencepiece_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
-_import_structure = {
- "configuration_plbart": ["PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "PLBartConfig"],
-}
+_import_structure = {"configuration_plbart": ["PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "PLBartConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_plbart"] = ["PLBartTokenizer"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_plbart"] = [
"PLBART_PRETRAINED_MODEL_ARCHIVE_LIST",
"PLBartForCausalLM",
@@ -41,10 +55,20 @@
if TYPE_CHECKING:
from .configuration_plbart import PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP, PLBartConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_plbart import PLBartTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_plbart import (
PLBART_PRETRAINED_MODEL_ARCHIVE_LIST,
PLBartForCausalLM,
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index 97e3ec680cbf..d03ddf33ebfa 100755
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -94,7 +94,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -116,7 +116,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart
@@ -233,7 +233,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -249,7 +250,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -270,7 +272,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -784,7 +787,8 @@ def forward(
if head_mask is not None:
if head_mask.size()[0] != (len(self.layers)):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, encoder_layer in enumerate(self.layers):
@@ -879,11 +883,13 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
@@ -1022,7 +1028,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
@@ -1230,9 +1237,9 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
- r"encoder\.version",
- r"decoder\.version",
- r"lm_head\.weight",
+ r"encoder.version",
+ r"decoder.version",
+ r"lm_head.weight",
]
def __init__(self, config: PLBartConfig):
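
The `float("-inf")` to `torch.finfo(dtype).min` switch in `_make_causal_mask` (and the analogous ProphetNet mask changes further down) is presumably about keeping masked attention scores finite in reduced precision and traced graphs. A small illustration of the failure mode it avoids, not part of the patch:

    import torch

    row_inf = torch.full((1, 4), float("-inf"))
    row_min = torch.full((1, 4), torch.finfo(torch.float32).min)

    # A row that is entirely masked with -inf turns into NaN after softmax,
    # while the finite dtype minimum degrades gracefully to a uniform distribution.
    print(torch.softmax(row_inf, dim=-1))  # nan, nan, nan, nan
    print(torch.softmax(row_min, dim=-1))  # 0.25, 0.25, 0.25, 0.25
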
diff --git a/src/transformers/models/plbart/tokenization_plbart.py b/src/transformers/models/plbart/tokenization_plbart.py
index 4c302e8b62ce..f6f393f9b8bd 100644
--- a/src/transformers/models/plbart/tokenization_plbart.py
+++ b/src/transformers/models/plbart/tokenization_plbart.py
@@ -14,7 +14,6 @@
# limitations under the License.
import os
-from contextlib import contextmanager
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple
@@ -33,19 +32,41 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"uclanlp/plbart-base": "https://huggingface.co/uclanlp/plbart-base/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-c-cpp-defect-detection": "https://huggingface.co/uclanlp/plbart-c-cpp-defect-detection/resolve/main/sentencepiece.bpe.model",
+ "uclanlp/plbart-c-cpp-defect-detection": (
+ "https://huggingface.co/uclanlp/plbart-c-cpp-defect-detection/resolve/main/sentencepiece.bpe.model"
+ ),
"uclanlp/plbart-cs-java": "https://huggingface.co/uclanlp/plbart-cs-java/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-en_XX-java": "https://huggingface.co/uclanlp/plbart-en_XX-java/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-go-en_XX": "https://huggingface.co/uclanlp/plbart-go-en_XX/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-java-clone-detection": "https://huggingface.co/uclanlp/plbart-java-clone-detection/resolve/main/sentencepiece.bpe.model",
+ "uclanlp/plbart-en_XX-java": (
+ "https://huggingface.co/uclanlp/plbart-en_XX-java/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-go-en_XX": (
+ "https://huggingface.co/uclanlp/plbart-go-en_XX/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-java-clone-detection": (
+ "https://huggingface.co/uclanlp/plbart-java-clone-detection/resolve/main/sentencepiece.bpe.model"
+ ),
"uclanlp/plbart-java-cs": "https://huggingface.co/uclanlp/plbart-java-cs/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-java-en_XX": "https://huggingface.co/uclanlp/plbart-java-en_XX/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-javascript-en_XX": "https://huggingface.co/uclanlp/plbart-javascript-en_XX/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-php-en_XX": "https://huggingface.co/uclanlp/plbart-php-en_XX/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-python-en_XX": "https://huggingface.co/uclanlp/plbart-python-en_XX/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-refine-java-medium": "https://huggingface.co/uclanlp/plbart-refine-java-medium/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-refine-java-small": "https://huggingface.co/uclanlp/plbart-refine-java-small/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-ruby-en_XX": "https://huggingface.co/uclanlp/plbart-ruby-en_XX/resolve/main/sentencepiece.bpe.model",
+ "uclanlp/plbart-java-en_XX": (
+ "https://huggingface.co/uclanlp/plbart-java-en_XX/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-javascript-en_XX": (
+ "https://huggingface.co/uclanlp/plbart-javascript-en_XX/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-php-en_XX": (
+ "https://huggingface.co/uclanlp/plbart-php-en_XX/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-python-en_XX": (
+ "https://huggingface.co/uclanlp/plbart-python-en_XX/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-refine-java-medium": (
+ "https://huggingface.co/uclanlp/plbart-refine-java-medium/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-refine-java-small": (
+ "https://huggingface.co/uclanlp/plbart-refine-java-small/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-ruby-en_XX": (
+ "https://huggingface.co/uclanlp/plbart-ruby-en_XX/resolve/main/sentencepiece.bpe.model"
+ ),
}
}
@@ -79,8 +100,8 @@ class PLBartTokenizer(PreTrainedTokenizer):
Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
[SentencePiece](https://github.com/google/sentencepiece).
- The tokenization method is `<tokens> <eos> <language code>` for source language documents, and ``<language code>
- <tokens> <eos>``` for target language documents.
+ The tokenization method is `<tokens> <eos> <language code>` for source language documents, and `<language code>
+ <tokens> <eos>` for target language documents.
Args:
vocab_file (`str`):
@@ -131,10 +152,7 @@ class PLBartTokenizer(PreTrainedTokenizer):
>>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX")
>>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
>>> expected_translation_english = "Returns the maximum value of a b c."
- >>> inputs = tokenizer(example_python_phrase, return_tensors="pt")
- >>> with tokenizer.as_target_tokenizer():
- ... labels = tokenizer(expected_translation_english, return_tensors="pt")
- >>> inputs["labels"] = labels["input_ids"]
+ >>> inputs = tokenizer(example_python_phrase, text_target=expected_translation_english, return_tensors="pt")
```"""
vocab_files_names = VOCAB_FILES_NAMES
@@ -419,15 +437,11 @@ def prepare_seq2seq_batch(
self.tgt_lang = tgt_lang
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
- @contextmanager
- def as_target_tokenizer(self):
- """
- Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
- sequence-to-sequence models that need a slightly different processing for the labels.
- """
- self.set_tgt_lang_special_tokens(self.tgt_lang)
- yield
- self.set_src_lang_special_tokens(self.src_lang)
+ def _switch_to_input_mode(self):
+ return self.set_src_lang_special_tokens(self.src_lang)
+
+ def _switch_to_target_mode(self):
+ return self.set_tgt_lang_special_tokens(self.tgt_lang)
def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
diff --git a/src/transformers/models/poolformer/__init__.py b/src/transformers/models/poolformer/__init__.py
index 799752067fda..7cb5e4acacb9 100644
--- a/src/transformers/models/poolformer/__init__.py
+++ b/src/transformers/models/poolformer/__init__.py
@@ -18,17 +18,25 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
-_import_structure = {
- "configuration_poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"],
-}
+_import_structure = {"configuration_poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_poolformer"] = ["PoolFormerFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_poolformer"] = [
"POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"PoolFormerForImageClassification",
@@ -40,10 +48,20 @@
if TYPE_CHECKING:
from .configuration_poolformer import POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, PoolFormerConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_poolformer import PoolFormerFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_poolformer import (
POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
PoolFormerForImageClassification,
diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py
index 2335c7cdc40c..b53c482da47b 100755
--- a/src/transformers/models/poolformer/modeling_poolformer.py
+++ b/src/transformers/models/poolformer/modeling_poolformer.py
@@ -50,40 +50,41 @@
]
-# Copied from transformers.models.vit.modeling_vit.to_2tuple
-def to_2tuple(x):
- if isinstance(x, collections.abc.Iterable):
- return x
- return (x, x)
-
-
-def drop_path(x, drop_prob: float = 0.0, training: bool = False):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is
- misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion:
- https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
- argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
"""
if drop_prob == 0.0 or not training:
- return x
+ return input
keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize
- output = x.div(keep_prob) * random_tensor
+ output = input.div(keep_prob) * random_tensor
return output
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->PoolFormer
class PoolFormerDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
- def __init__(self, drop_prob=None):
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
class PoolFormerEmbeddings(nn.Module):
"""
@@ -92,17 +93,17 @@ class PoolFormerEmbeddings(nn.Module):
def __init__(self, hidden_size, num_channels, patch_size, stride, padding, norm_layer=None):
super().__init__()
- patch_size = to_2tuple(patch_size)
- stride = to_2tuple(stride)
- padding = to_2tuple(padding)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride)
+ padding = padding if isinstance(padding, collections.abc.Iterable) else (padding, padding)
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=padding)
self.norm = norm_layer(hidden_size) if norm_layer else nn.Identity()
def forward(self, pixel_values):
- x = self.projection(pixel_values)
- x = self.norm(x)
- return x
+ embeddings = self.projection(pixel_values)
+ embeddings = self.norm(embeddings)
+ return embeddings
class PoolFormerGroupNorm(nn.GroupNorm):
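
For intuition on the `drop_path` helper copied in above: stochastic depth drops whole residual branches per sample and rescales the survivors by `1 / keep_prob`. A tiny standalone illustration of the same arithmetic, not part of the patch:

    import torch

    torch.manual_seed(0)
    x = torch.ones(4, 3)             # 4 samples, all ones
    drop_prob = 0.5
    keep_prob = 1 - drop_prob

    # One Bernoulli(keep_prob) draw per sample, broadcast over the remaining dims.
    mask = (keep_prob + torch.rand(x.shape[0], 1)).floor_()
    out = x.div(keep_prob) * mask    # dropped samples -> 0, survivors scaled to 2.0 here
    print(out)
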
diff --git a/src/transformers/models/prophetnet/__init__.py b/src/transformers/models/prophetnet/__init__.py
index be4baf4a16f1..b739fb9f5d5a 100644
--- a/src/transformers/models/prophetnet/__init__.py
+++ b/src/transformers/models/prophetnet/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
@@ -26,7 +26,12 @@
"tokenization_prophetnet": ["ProphetNetTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_prophetnet"] = [
"PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"ProphetNetDecoder",
@@ -42,7 +47,12 @@
from .configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig
from .tokenization_prophetnet import ProphetNetTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_prophetnet import (
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ProphetNetDecoder,
diff --git a/src/transformers/models/prophetnet/configuration_prophetnet.py b/src/transformers/models/prophetnet/configuration_prophetnet.py
index 9a6574c84d2b..40f5939d99bc 100644
--- a/src/transformers/models/prophetnet/configuration_prophetnet.py
+++ b/src/transformers/models/prophetnet/configuration_prophetnet.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" ProphetNet model configuration"""
+from typing import Callable, Optional, Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -22,7 +23,9 @@
logger = logging.get_logger(__name__)
PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "microsoft/prophetnet-large-uncased": "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/config.json",
+ "microsoft/prophetnet-large-uncased": (
+ "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/config.json"
+ ),
}
@@ -103,32 +106,32 @@ class ProphetNetConfig(PretrainedConfig):
def __init__(
self,
- activation_dropout=0.1,
- activation_function="gelu",
- vocab_size=30522,
- hidden_size=1024,
- encoder_ffn_dim=4096,
- num_encoder_layers=12,
- num_encoder_attention_heads=16,
- decoder_ffn_dim=4096,
- num_decoder_layers=12,
- num_decoder_attention_heads=16,
- attention_dropout=0.1,
- dropout=0.1,
- max_position_embeddings=512,
- init_std=0.02,
- is_encoder_decoder=True,
- add_cross_attention=True,
- decoder_start_token_id=0,
- ngram=2,
- num_buckets=32,
- relative_max_distance=128,
- disable_ngram_loss=False,
- eps=0.0,
- use_cache=True,
- pad_token_id=0,
- bos_token_id=1,
- eos_token_id=2,
+ activation_dropout: Optional[float] = 0.1,
+ activation_function: Optional[Union[str, Callable]] = "gelu",
+ vocab_size: Optional[int] = 30522,
+ hidden_size: Optional[int] = 1024,
+ encoder_ffn_dim: Optional[int] = 4096,
+ num_encoder_layers: Optional[int] = 12,
+ num_encoder_attention_heads: Optional[int] = 16,
+ decoder_ffn_dim: Optional[int] = 4096,
+ num_decoder_layers: Optional[int] = 12,
+ num_decoder_attention_heads: Optional[int] = 16,
+ attention_dropout: Optional[float] = 0.1,
+ dropout: Optional[float] = 0.1,
+ max_position_embeddings: Optional[int] = 512,
+ init_std: Optional[float] = 0.02,
+ is_encoder_decoder: Optional[bool] = True,
+ add_cross_attention: Optional[bool] = True,
+ decoder_start_token_id: Optional[int] = 0,
+ ngram: Optional[int] = 2,
+ num_buckets: Optional[int] = 32,
+ relative_max_distance: Optional[int] = 128,
+ disable_ngram_loss: Optional[bool] = False,
+ eps: Optional[float] = 0.0,
+ use_cache: Optional[bool] = True,
+ pad_token_id: Optional[int] = 0,
+ bos_token_id: Optional[int] = 1,
+ eos_token_id: Optional[int] = 2,
**kwargs
):
self.vocab_size = vocab_size
@@ -174,5 +177,6 @@ def num_hidden_layers(self) -> int:
@num_hidden_layers.setter
def num_hidden_layers(self, value):
raise NotImplementedError(
- "This model does not support the setting of `num_hidden_layers`. Please set `num_encoder_layers` and `num_decoder_layers`."
+ "This model does not support the setting of `num_hidden_layers`. Please set `num_encoder_layers` and"
+ " `num_decoder_layers`."
)
diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py
index c869d6373bc5..537d1f80d448 100644
--- a/src/transformers/models/prophetnet/modeling_prophetnet.py
+++ b/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -187,7 +187,9 @@ def ngram_attention_bias(sequence_length, ngram, device, dtype):
"""
This function computes the bias for the predict stream
"""
- left_block = torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * float("-inf")
+ left_block = (
+ torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * torch.finfo(dtype).min
+ )
right_block = left_block.detach().clone()
# create bias
for stream_idx in range(ngram):
@@ -326,7 +328,8 @@ class ProphetNetSeq2SeqLMOutput(ModelOutput):
@property
def decoder_cross_attentions(self):
warnings.warn(
- "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions` instead.",
+ "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
+ " instead.",
FutureWarning,
)
return self.cross_attentions
@@ -344,7 +347,7 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput):
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
- last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`):
+ last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`, *optional*):
Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
@@ -411,7 +414,8 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput):
@property
def decoder_cross_attentions(self):
warnings.warn(
- "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions` instead.",
+ "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
+ " instead.",
FutureWarning,
)
return self.cross_attentions
@@ -562,9 +566,10 @@ def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
- assert (
- decoder_start_token_id is not None
- ), "self.model.config.decoder_start_token_id has to be defined. In ProphetNet it is usually set to the pad_token_id. See ProphetNet docs for more information"
+ assert decoder_start_token_id is not None, (
+ "self.model.config.decoder_start_token_id has to be defined. In ProphetNet it is usually set to the"
+ " pad_token_id. See ProphetNet docs for more information"
+ )
# shift inputs to the right
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
@@ -587,7 +592,7 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
the forward function.
"""
- def __init__(self, config: ProphetNetConfig):
+ def __init__(self, config: ProphetNetConfig) -> None:
self.max_length = config.max_position_embeddings
super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)
@@ -639,9 +644,10 @@ def __init__(
self.num_attn_heads = num_attn_heads
self.head_dim = hidden_size // num_attn_heads
- assert (
- self.head_dim * num_attn_heads == hidden_size
- ), "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and `config.num_decoder_attention_heads`"
+ assert self.head_dim * num_attn_heads == hidden_size, (
+ "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and"
+ " `config.num_decoder_attention_heads`"
+ )
self.key_proj = nn.Linear(hidden_size, hidden_size)
self.value_proj = nn.Linear(hidden_size, hidden_size)
@@ -708,7 +714,10 @@ def forward(
batch_size * self.num_attn_heads,
tgt_len,
src_len,
- ), f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size {attn_weights.shape}"
+ ), (
+ f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size"
+ f" {attn_weights.shape}"
+ )
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
if attention_mask is not None and attention_mask.dim() == 0:
@@ -717,7 +726,10 @@ def forward(
self.num_attn_heads * batch_size,
1,
src_len,
- ), f"`attention_mask` should be `None` or of shape attention_mask.size() == {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}"
+ ), (
+ "`attention_mask` should be `None` or of shape attention_mask.size() =="
+ f" {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}"
+ )
if attention_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights + attention_mask
@@ -735,9 +747,10 @@ def forward(
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
- assert layer_head_mask.size() == (
- self.num_attn_heads,
- ), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
+ assert layer_head_mask.size() == (self.num_attn_heads,), (
+ f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
batch_size, self.num_attn_heads, tgt_len, src_len
)
@@ -757,7 +770,10 @@ def forward(
batch_size * self.num_attn_heads,
tgt_len,
self.head_dim,
- ), f"`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of shape {attn_output.size()}"
+ ), (
+ f"`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of"
+ f" shape {attn_output.size()}"
+ )
attn_output = (
attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim)
@@ -847,7 +863,10 @@ def forward(
batch_size,
ngram_sequence_length,
hidden_size,
- ], f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape {hidden_states.shape}"
+ ], (
+ f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
+ f" {hidden_states.shape}"
+ )
# project
query_states = self.query_proj(hidden_states)
@@ -916,9 +935,10 @@ def forward(
).type_as(main_attn_weights)
if layer_head_mask is not None:
- assert layer_head_mask.size() == (
- self.num_attn_heads,
- ), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
+ assert layer_head_mask.size() == (self.num_attn_heads,), (
+ f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
batch_size, self.num_attn_heads, -1, sequence_length
)
@@ -979,9 +999,10 @@ def forward(
).type_as(predict_attn_weights)
if layer_head_mask is not None:
- assert layer_head_mask.size() == (
- self.num_attn_heads,
- ), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
+ assert layer_head_mask.size() == (self.num_attn_heads,), (
+ f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view(
self.ngram, batch_size, self.num_attn_heads, sequence_length, 2 * sequence_length
)
@@ -1317,7 +1338,7 @@ def forward(
if attention_mask is not None:
extended_attention_mask = (
1.0 - attention_mask[:, None, :].repeat(self.config.num_encoder_attention_heads, 1, 1)
- ) * -10000.0
+ ) * torch.finfo(self.dtype).min
extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
else:
extended_attention_mask = None
@@ -1388,7 +1409,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
embeddings instead of randomly initialized word embeddings.
"""
- def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = None):
+ def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
super().__init__(config)
self.ngram = config.ngram
@@ -1535,7 +1556,7 @@ def forward(
if encoder_attention_mask is not None:
extended_encoder_attention_mask = (
1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_decoder_attention_heads, 1, 1)
- ) * -10000.0
+ ) * torch.finfo(self.dtype).min
extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
else:
extended_encoder_attention_mask = None
@@ -1559,9 +1580,10 @@ def forward(
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
- assert attn_mask.size()[0] == (
- len(self.layers)
- ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ assert attn_mask.size()[0] == (len(self.layers)), (
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
# grad cannot be kept because tensor is sliced
@@ -1693,7 +1715,10 @@ def prepare_attention_mask(self, hidden_states, attention_mask):
# get causal mask
causal_mask = torch.full(
- (seq_length, seq_length), -float("inf"), dtype=hidden_states.dtype, device=hidden_states.device
+ (seq_length, seq_length),
+ torch.finfo(hidden_states.dtype).min,
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
)
causal_mask = torch.triu(causal_mask, 1)
extended_causal_mask = causal_mask[:seq_length, :seq_length][None, :, :].expand(
@@ -1702,7 +1727,7 @@ def prepare_attention_mask(self, hidden_states, attention_mask):
# add usual attention mask
if attention_mask is not None:
- extended_attention_mask = (1.0 - attention_mask[:, None, :]) * -10000.0
+ extended_attention_mask = (1.0 - attention_mask[:, None, :]) * torch.finfo(self.dtype).min
extended_attention_mask = extended_causal_mask + extended_attention_mask
else:
extended_attention_mask = extended_causal_mask
@@ -1730,7 +1755,7 @@ def prepare_predict_attention_mask(self, hidden_states, attention_mask):
# add usual attention mask
if attention_mask is not None:
- extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * -10000.0
+ extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * torch.finfo(self.dtype).min
extended_attention_mask = extended_attention_mask.expand((self.ngram, batch_size, seq_length, seq_length))
# predicted stream attention_mask should always be 0
extended_attention_mask = torch.cat(
@@ -1749,7 +1774,7 @@ def prepare_predict_attention_mask(self, hidden_states, attention_mask):
PROPHETNET_START_DOCSTRING,
)
class ProphetNetModel(ProphetNetPreTrainedModel):
- def __init__(self, config):
+ def __init__(self, config: ProphetNetConfig):
super().__init__(config)
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
@@ -2081,11 +2106,12 @@ def get_decoder(self):
@add_start_docstrings(
- "The standalone decoder part of the ProphetNetModel with a lm head on top. The model can be used for causal language modeling.",
+ "The standalone decoder part of the ProphetNetModel with a lm head on top. The model can be used for causal"
+ " language modeling.",
PROPHETNET_START_DOCSTRING,
)
class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
- def __init__(self, config):
+ def __init__(self, config: ProphetNetConfig):
# set config for CLM
config = copy.deepcopy(config)
config.is_decoder = True
@@ -2320,7 +2346,7 @@ class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
classes.
"""
- def __init__(self, config):
+ def __init__(self, config: ProphetNetConfig):
super().__init__(config)
self.decoder = ProphetNetDecoder(config)
diff --git a/src/transformers/models/prophetnet/tokenization_prophetnet.py b/src/transformers/models/prophetnet/tokenization_prophetnet.py
index 5bc3951b7969..c77259740390 100644
--- a/src/transformers/models/prophetnet/tokenization_prophetnet.py
+++ b/src/transformers/models/prophetnet/tokenization_prophetnet.py
@@ -15,7 +15,7 @@
import collections
import os
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
@@ -28,7 +28,9 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "microsoft/prophetnet-large-uncased": "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/prophetnet.tokenizer",
+ "microsoft/prophetnet-large-uncased": (
+ "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/prophetnet.tokenizer"
+ ),
}
}
@@ -109,17 +111,17 @@ class ProphetNetTokenizer(PreTrainedTokenizer):
def __init__(
self,
- vocab_file,
- do_lower_case=True,
- do_basic_tokenize=True,
- never_split=None,
- unk_token="[UNK]",
- sep_token="[SEP]",
- x_sep_token="[X_SEP]",
- pad_token="[PAD]",
- mask_token="[MASK]",
- tokenize_chinese_chars=True,
- strip_accents=None,
+ vocab_file: str,
+ do_lower_case: Optional[bool] = True,
+ do_basic_tokenize: Optional[bool] = True,
+ never_split: Optional[Iterable] = None,
+ unk_token: Optional[str] = "[UNK]",
+ sep_token: Optional[str] = "[SEP]",
+ x_sep_token: Optional[str] = "[X_SEP]",
+ pad_token: Optional[str] = "[PAD]",
+ mask_token: Optional[str] = "[MASK]",
+ tokenize_chinese_chars: Optional[bool] = True,
+ strip_accents: Optional[bool] = None,
**kwargs
):
super().__init__(
@@ -139,8 +141,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
@@ -175,21 +177,24 @@ def _tokenize(self, text):
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
- def _convert_token_to_id(self, token):
+ def _convert_token_to_id(self, token: str):
"""Converts a token (str) in an id using the vocab."""
return self.vocab.get(token, self.vocab.get(self.unk_token))
- def _convert_id_to_token(self, index):
+ def _convert_id_to_token(self, index: int):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.ids_to_tokens.get(index, self.unk_token)
- def convert_tokens_to_string(self, tokens):
+ def convert_tokens_to_string(self, tokens: str):
"""Converts a sequence of tokens (string) in a single string."""
out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string
def get_special_tokens_mask(
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ self,
+ token_ids_0: List[int],
+ token_ids_1: Optional[List[int]] = None,
+ already_has_special_tokens: Optional[bool] = False,
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
diff --git a/src/transformers/models/qdqbert/__init__.py b/src/transformers/models/qdqbert/__init__.py
index 28fb61c2193c..60f03338f480 100644
--- a/src/transformers/models/qdqbert/__init__.py
+++ b/src/transformers/models/qdqbert/__init__.py
@@ -17,14 +17,17 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
-_import_structure = {
- "configuration_qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"],
-}
+_import_structure = {"configuration_qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_qdqbert"] = [
"QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"QDQBertForMaskedLM",
@@ -44,7 +47,12 @@
if TYPE_CHECKING:
from .configuration_qdqbert import QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, QDQBertConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_qdqbert import (
QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
QDQBertForMaskedLM,
diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py
index e7be1b45183d..35890625b1ff 100755
--- a/src/transformers/models/qdqbert/modeling_qdqbert.py
+++ b/src/transformers/models/qdqbert/modeling_qdqbert.py
@@ -19,11 +19,10 @@
import math
import os
import warnings
-from typing import Optional
+from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
-from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -40,7 +39,7 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, is_torch_greater_than_1_6, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -62,8 +61,9 @@
from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizer
except OSError:
logger.error(
- "QDQBERT model are not usable since `pytorch_quantization` can't be loaded. "
- "Please try to reinstall it following the instructions here: https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
+ "QDQBERT model are not usable since `pytorch_quantization` can't be loaded. Please try to reinstall it"
+ " following the instructions here:"
+ " https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
)
_CHECKPOINT_FOR_DOC = "bert-base-uncased"
@@ -76,7 +76,7 @@
]
-def load_tf_weights_in_qdqbert(model, config, tf_checkpoint_path):
+def load_tf_weights_in_qdqbert(model, tf_checkpoint_path):
"""Load tf checkpoints in a pytorch model."""
try:
import re
@@ -166,7 +166,7 @@ def __init__(self, config):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
- if version.parse(torch.__version__) > version.parse("1.6.0"):
+ if is_torch_greater_than_1_6:
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long),
@@ -507,7 +507,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -848,7 +849,7 @@ class QDQBertModel(QDQBertPreTrainedModel):
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
"""
- def __init__(self, config, add_pooling_layer=True):
+ def __init__(self, config, add_pooling_layer: bool = True):
requires_backends(self, "pytorch_quantization")
super().__init__(config)
self.config = config
@@ -867,7 +868,7 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
- def _prune_heads(self, heads_to_prune):
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
@@ -884,20 +885,20 @@ class PreTrainedModel
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_values=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -1043,21 +1044,21 @@ def set_output_embeddings(self, new_embeddings):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- labels=None,
- past_key_values=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.LongTensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -1144,7 +1145,13 @@ def forward(
cross_attentions=outputs.cross_attentions,
)
- def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: Optional[torch.LongTensor],
+ past=None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **model_kwargs
+ ):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
@@ -1199,19 +1206,19 @@ def set_output_embeddings(self, new_embeddings):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1254,7 +1261,9 @@ def forward(
attentions=outputs.attentions,
)
- def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+ def prepare_inputs_for_generation(
+ self, input_ids: torch.LongTensor, attention_mask: Optional[torch.FloatTensor] = None, **model_kwargs
+ ):
input_shape = input_ids.shape
effective_batch_size = input_shape[0]
@@ -1289,18 +1298,18 @@ def __init__(self, config):
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
**kwargs,
- ):
+ ) -> Union[Tuple, NextSentencePredictorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
@@ -1331,7 +1340,8 @@ def forward(
if "next_sentence_label" in kwargs:
warnings.warn(
- "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
+ " `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
@@ -1399,17 +1409,17 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1496,17 +1506,17 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -1592,17 +1602,17 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1673,18 +1683,18 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- start_positions=None,
- end_positions=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
diff --git a/src/transformers/models/rag/__init__.py b/src/transformers/models/rag/__init__.py
index 00e88f7c0abd..7798e8a41574 100644
--- a/src/transformers/models/rag/__init__.py
+++ b/src/transformers/models/rag/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
@@ -27,7 +27,12 @@
"tokenization_rag": ["RagTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_rag"] = [
"RagModel",
"RagPreTrainedModel",
@@ -35,7 +40,12 @@
"RagTokenForGeneration",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_rag"] = [
"TFRagModel",
"TFRagPreTrainedModel",
@@ -49,10 +59,20 @@
from .retrieval_rag import RagRetriever
from .tokenization_rag import RagTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_rag import (
TFRagModel,
TFRagPreTrainedModel,
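The `rag/__init__.py` hunks above replace bare `if is_torch_available():` / `if is_tf_available():` checks with a `try`/`except OptionalDependencyNotAvailable` guard. A minimal, self-contained sketch of that guard pattern (the names below are illustrative stand-ins, not the transformers internals themselves):

import importlib.util


class OptionalDependencyNotAvailable(BaseException):
    """Illustrative stand-in for the sentinel exception used above."""


def is_torch_available() -> bool:
    # Simplified stand-in for the library's own availability check.
    return importlib.util.find_spec("torch") is not None


_import_structure = {"configuration_rag": ["RagConfig"]}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass  # torch-backed objects are simply left out of the import structure
else:
    _import_structure["modeling_rag"] = ["RagModel"]

The same pattern is repeated for every optional backend (tokenizers, sentencepiece, TensorFlow) in the `realm`, `reformer`, and `regnet` init files below.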
diff --git a/src/transformers/models/rag/configuration_rag.py b/src/transformers/models/rag/configuration_rag.py
index 2897642a7547..6046b934cd64 100644
--- a/src/transformers/models/rag/configuration_rag.py
+++ b/src/transformers/models/rag/configuration_rag.py
@@ -28,7 +28,7 @@
title_sep (`str`, *optional*, defaults to `" / "`):
Separator inserted between the title and the text of the retrieved document when calling [`RagRetriever`].
doc_sep (`str`, *optional*, defaults to `" // "`):
- Separator inserted between the the text of the retrieved document and the original input when calling
+ Separator inserted between the text of the retrieved document and the original input when calling
[`RagRetriever`].
n_docs (`int`, *optional*, defaults to 5):
Number of documents to retrieve.
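As a rough illustration of how these separators are meant to be used (a sketch of the documented behavior, not the retriever's actual formatting code), a retrieved document's title and text are joined with `title_sep`, and the result is joined to the original input with `doc_sep`:

# Hedged sketch: default separator values are taken from the docstring above.
title_sep = " / "
doc_sep = " // "

doc_title = "Paris"
doc_text = "Paris is the capital and most populous city of France."
question = "How many people live in Paris?"

context = doc_title + title_sep + doc_text + doc_sep + question
print(context)
# Paris / Paris is the capital and most populous city of France. // How many people live in Paris?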
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index 205e825cbc1f..41af393c6710 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -336,9 +336,10 @@ def from_pretrained_question_encoder_generator(
# by the value of the flag `is_generator` that we need to set correctly.
question_encoder = kwargs_question_encoder.pop("model", None)
if question_encoder is None:
- assert (
- question_encoder_pretrained_model_name_or_path is not None
- ), "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to be defined"
+ assert question_encoder_pretrained_model_name_or_path is not None, (
+ "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to"
+ " be defined"
+ )
from ..auto.modeling_auto import AutoModel
if "config" not in kwargs_question_encoder:
@@ -357,9 +358,10 @@ def from_pretrained_question_encoder_generator(
generator = kwargs_generator.pop("model", None)
if generator is None:
- assert (
- generator_pretrained_model_name_or_path is not None
- ), "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has to be defined"
+ assert generator_pretrained_model_name_or_path is not None, (
+ "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has"
+ " to be defined"
+ )
from ..auto.modeling_auto import AutoModelForSeq2SeqLM
if "config" not in kwargs_generator:
@@ -654,23 +656,27 @@ def forward(
question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
).squeeze(1)
else:
- assert (
- context_input_ids is not None
- ), "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
- assert (
- context_attention_mask is not None
- ), "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
- assert (
- doc_scores is not None
- ), "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
+ assert context_input_ids is not None, (
+ "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can"
+ " set a retriever using the `set_retriever(...)` function."
+ )
+ assert context_attention_mask is not None, (
+ "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you"
+ " can set a retriever using the `set_retriever(...)` function."
+ )
+ assert doc_scores is not None, (
+ "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a"
+ " retriever using the `set_retriever(...)` function."
+ )
assert (
doc_scores is not None
), "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function."
- assert (
- doc_scores.shape[1] % n_docs
- ) == 0, f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is {context_input_ids.shape[0]}."
+ assert (doc_scores.shape[1] % n_docs) == 0, (
+ f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
+ f" {context_input_ids.shape[0]}."
+ )
# Decoder input without context documents
if decoder_input_ids is not None:
@@ -812,8 +818,7 @@ def forward(
>>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
>>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
- >>> with tokenizer.as_target_tokenizer():
- ... targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")
+ >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
>>> input_ids = inputs["input_ids"]
>>> labels = targets["input_ids"]
>>> outputs = model(input_ids=input_ids, labels=labels)
@@ -1022,12 +1027,14 @@ def generate(
new_input_ids = input_ids[index : index + 1].repeat(num_candidates, 1)
outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
else: # input_ids is None, need context_input_ids/mask and doc_scores
- assert (
- context_attention_mask is not None
- ), "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
- assert (
- doc_scores is not None
- ), "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
+ assert context_attention_mask is not None, (
+ "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you"
+ " can set a retriever using the `set_retriever(...)` function."
+ )
+ assert doc_scores is not None, (
+ "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a"
+ " retriever using the `set_retriever(...)` function."
+ )
individual_input_ids = generator_input_ids.repeat(
num_candidates, 1
@@ -1279,8 +1286,7 @@ def forward(
>>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
>>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
- >>> with tokenizer.as_target_tokenizer():
- ... targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")
+ >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
>>> input_ids = inputs["input_ids"]
>>> labels = targets["input_ids"]
>>> outputs = model(input_ids=input_ids, labels=labels)
@@ -1567,9 +1573,10 @@ def generate(
1
)
- assert (
- context_input_ids.shape[0] % n_docs
- ) == 0, f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is {context_input_ids.shape[0]}."
+ assert (context_input_ids.shape[0] % n_docs) == 0, (
+ f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
+ f" {context_input_ids.shape[0]}."
+ )
# batch_size
batch_size = context_input_ids.shape[0] // n_docs
diff --git a/src/transformers/models/rag/modeling_tf_rag.py b/src/transformers/models/rag/modeling_tf_rag.py
index 30f50a29ff40..26482026baa8 100644
--- a/src/transformers/models/rag/modeling_tf_rag.py
+++ b/src/transformers/models/rag/modeling_tf_rag.py
@@ -321,9 +321,10 @@ def from_pretrained_question_encoder_generator(
# by the value of the flag `is_generator` that we need to set correctly.
question_encoder = kwargs_question_encoder.pop("model", None)
if question_encoder is None:
- assert (
- question_encoder_pretrained_model_name_or_path is not None
- ), "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to be defined"
+ assert question_encoder_pretrained_model_name_or_path is not None, (
+ "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to"
+ " be defined"
+ )
from ..auto.modeling_tf_auto import TFAutoModel
@@ -343,9 +344,10 @@ def from_pretrained_question_encoder_generator(
generator = kwargs_generator.pop("generator", None)
if generator is None:
- assert (
- generator_pretrained_model_name_or_path is not None
- ), "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has to be defined"
+ assert generator_pretrained_model_name_or_path is not None, (
+ "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has"
+ " to be defined"
+ )
from ..auto.modeling_tf_auto import TFAutoModelForSeq2SeqLM
@@ -632,23 +634,27 @@ def call(
)
else:
- assert (
- context_input_ids is not None
- ), "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
- assert (
- context_attention_mask is not None
- ), "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
- assert (
- doc_scores is not None
- ), "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
+ assert context_input_ids is not None, (
+ "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can"
+ " set a retriever using the `set_retriever(...)` function."
+ )
+ assert context_attention_mask is not None, (
+ "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you"
+ " can set a retriever using the `set_retriever(...)` function."
+ )
+ assert doc_scores is not None, (
+ "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a"
+ " retriever using the `set_retriever(...)` function."
+ )
assert (
doc_scores is not None
), "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function."
- assert (
- doc_scores.shape[1] % n_docs
- ) == 0, f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is {context_input_ids.shape[0]}."
+ assert (doc_scores.shape[1] % n_docs) == 0, (
+ f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
+ f" {context_input_ids.shape[0]}."
+ )
# Decoder input without context documents
if decoder_input_ids is not None:
@@ -1149,9 +1155,10 @@ def generate(
)
doc_scores = tf.squeeze(doc_scores, axis=1)
- assert (
- context_input_ids.shape[0] % n_docs
- ) == 0, f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is {context_input_ids.shape[0]}."
+ assert (context_input_ids.shape[0] % n_docs) == 0, (
+ f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
+ f" {context_input_ids.shape[0]}."
+ )
batch_size = context_input_ids.shape[0] // n_docs
@@ -1286,9 +1293,10 @@ def shift_tokens_right(self, input_ids, start_token_id=None):
if start_token_id is None:
start_token_id = self.generator.config.decoder_start_token_id
- assert (
- start_token_id is not None
- ), "self.generator.config.decoder_start_token_id has to be defined. In Rag we commonly use Bart as generator, see Bart docs for more information"
+ assert start_token_id is not None, (
+ "self.generator.config.decoder_start_token_id has to be defined. In Rag we commonly use Bart as"
+ " generator, see Bart docs for more information"
+ )
pad_token_id = self.generator.config.pad_token_id
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
@@ -1325,6 +1333,8 @@ def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0
# Adopted modeling_tf_bart + add smooth_loss to match with pytorch version
def hf_compute_loss(self, labels, y_pred, smooth_epsilon=0.0, from_logits=True, reduce_loss=False):
"""CrossEntropyLoss that ignores pad tokens"""
+ # Matt: As written, this loss is not XLA-compatible, but it's doing some very weird things
+ # and I don't feel comfortable converting it.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.SUM,
@@ -1745,12 +1755,14 @@ def generate(
new_input_ids = tf.tile(input_ids[index : index + 1], (num_candidates, 1))
outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
else: # input_ids is None, need context_input_ids/mask and doc_scores
- assert (
- context_attention_mask is not None
- ), "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
- assert (
- doc_scores is not None
- ), "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
+ assert context_attention_mask is not None, (
+ "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you"
+ " can set a retriever using the `set_retriever(...)` function."
+ )
+ assert doc_scores is not None, (
+ "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a"
+ " retriever using the `set_retriever(...)` function."
+ )
individual_input_ids = tf.tile(
generator_input_ids, (num_candidates, 1)
diff --git a/src/transformers/models/rag/retrieval_rag.py b/src/transformers/models/rag/retrieval_rag.py
index 7a3c5635f24f..797c1a7332ac 100644
--- a/src/transformers/models/rag/retrieval_rag.py
+++ b/src/transformers/models/rag/retrieval_rag.py
@@ -23,7 +23,7 @@
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding
-from ...utils import cached_path, is_datasets_available, is_faiss_available, is_remote_url, logging, requires_backends
+from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends
from .configuration_rag import RagConfig
from .tokenization_rag import RagTokenizer
@@ -111,22 +111,21 @@ def __init__(self, vector_size, index_path):
self._index_initialized = False
def _resolve_path(self, index_path, filename):
- assert os.path.isdir(index_path) or is_remote_url(index_path), "Please specify a valid `index_path`."
- archive_file = os.path.join(index_path, filename)
+ is_local = os.path.isdir(index_path)
try:
# Load from URL or cache if already cached
- resolved_archive_file = cached_path(archive_file)
+ resolved_archive_file = cached_file(index_path, filename)
except EnvironmentError:
msg = (
- f"Can't load '{archive_file}'. Make sure that:\n\n"
+ f"Can't load '{filename}'. Make sure that:\n\n"
f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n"
f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
)
raise EnvironmentError(msg)
- if resolved_archive_file == archive_file:
- logger.info(f"loading file {archive_file}")
+ if is_local:
+ logger.info(f"loading file {resolved_archive_file}")
else:
- logger.info(f"loading file {archive_file} from cache at {resolved_archive_file}")
+ logger.info(f"loading file {filename} from cache at {resolved_archive_file}")
return resolved_archive_file
def _load_passages(self):
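The hunk above drops the manual `os.path.join` + `cached_path` combination in favor of `cached_file(index_path, filename)`, which handles both local directories and Hub downloads and leaves cache resolution to the library. A small hedged sketch of the call (the repo id and filename below are placeholders):

# Sketch only: resolve a file either from a local directory or from the Hub cache.
from transformers.utils import cached_file

try:
    resolved_file = cached_file("facebook/rag-token-nq", "config.json")
except EnvironmentError:
    # Mirror the hunk above: surface a readable message instead of a bare stack trace.
    resolved_file = None

print(resolved_file)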
diff --git a/src/transformers/models/rag/tokenization_rag.py b/src/transformers/models/rag/tokenization_rag.py
index d92ca1788faa..5b6ec67e6bf8 100644
--- a/src/transformers/models/rag/tokenization_rag.py
+++ b/src/transformers/models/rag/tokenization_rag.py
@@ -15,7 +15,6 @@
"""Tokenization classes for RAG."""
import os
import warnings
-from contextlib import contextmanager
from typing import List, Optional
from ...tokenization_utils_base import BatchEncoding
@@ -68,16 +67,12 @@ def batch_decode(self, *args, **kwargs):
def decode(self, *args, **kwargs):
return self.generator.decode(*args, **kwargs)
- @contextmanager
- def as_target_tokenizer(self):
- """
- Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
- sequence-to-sequence models that need a slightly different processing for the labels.
- """
- self.current_tokenizer = self.generator
- yield
+ def _switch_to_input_mode(self):
self.current_tokenizer = self.question_encoder
+ def _switch_to_target_mode(self):
+ self.current_tokenizer = self.generator
+
def prepare_seq2seq_batch(
self,
src_texts: List[str],
@@ -110,17 +105,16 @@ def prepare_seq2seq_batch(
if tgt_texts is None:
return model_inputs
# Process tgt_texts
- with self.as_target_tokenizer():
- if max_target_length is None:
- max_target_length = self.current_tokenizer.model_max_length
- labels = self(
- tgt_texts,
- add_special_tokens=True,
- return_tensors=return_tensors,
- padding=padding,
- max_length=max_target_length,
- truncation=truncation,
- **kwargs,
- )
+ if max_target_length is None:
+ max_target_length = self.current_tokenizer.model_max_length
+ labels = self(
+ text_target=tgt_texts,
+ add_special_tokens=True,
+ return_tensors=return_tensors,
+ padding=padding,
+ max_length=max_target_length,
+ truncation=truncation,
+ **kwargs,
+ )
model_inputs["labels"] = labels["input_ids"]
return model_inputs
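With `as_target_tokenizer` removed, target texts are tokenized through the `text_target` argument, exactly as the updated RAG docstrings above show. A short usage sketch mirroring those docstrings:

# Usage sketch taken from the updated docstrings above; no retriever is needed just to tokenize.
from transformers import RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")

inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")

input_ids = inputs["input_ids"]
labels = targets["input_ids"]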
diff --git a/src/transformers/models/realm/__init__.py b/src/transformers/models/realm/__init__.py
index db113dbd5b29..2464c0ae27d9 100644
--- a/src/transformers/models/realm/__init__.py
+++ b/src/transformers/models/realm/__init__.py
@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
@@ -25,10 +25,20 @@
"tokenization_realm": ["RealmTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_realm_fast"] = ["RealmTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_realm"] = [
"REALM_PRETRAINED_MODEL_ARCHIVE_LIST",
"RealmEmbedder",
@@ -46,10 +56,20 @@
from .configuration_realm import REALM_PRETRAINED_CONFIG_ARCHIVE_MAP, RealmConfig
from .tokenization_realm import RealmTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_realm import RealmTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_realm import (
REALM_PRETRAINED_MODEL_ARCHIVE_LIST,
RealmEmbedder,
diff --git a/src/transformers/models/realm/configuration_realm.py b/src/transformers/models/realm/configuration_realm.py
index d3383bd897c5..8d816a736e7a 100644
--- a/src/transformers/models/realm/configuration_realm.py
+++ b/src/transformers/models/realm/configuration_realm.py
@@ -21,10 +21,18 @@
logger = logging.get_logger(__name__)
REALM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/config.json",
- "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/config.json",
- "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/config.json",
- "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/config.json",
+ "google/realm-cc-news-pretrained-embedder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/config.json"
+ ),
+ "google/realm-cc-news-pretrained-encoder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/config.json"
+ ),
+ "google/realm-cc-news-pretrained-scorer": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/config.json"
+ ),
+ "google/realm-cc-news-pretrained-openqa": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/config.json"
+ ),
"google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/config.json",
"google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/config.json",
"google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/config.json",
diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py
index c467dcd30af8..6ee2b1fd14b4 100644
--- a/src/transformers/models/realm/modeling_realm.py
+++ b/src/transformers/models/realm/modeling_realm.py
@@ -20,7 +20,6 @@
from typing import Optional, Tuple, Union
import torch
-from packaging import version
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -32,7 +31,12 @@
ModelOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ is_torch_greater_than_1_6,
+ prune_linear_layer,
+)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_realm import RealmConfig
@@ -181,7 +185,7 @@ def __init__(self, config):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
- if version.parse(torch.__version__) > version.parse("1.6.0"):
+ if is_torch_greater_than_1_6:
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long),
@@ -502,7 +506,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -868,8 +873,8 @@ def _spans_given_width(width):
return starts, ends, span_masks
- def mask_to_score(mask):
- return (1.0 - mask.type(torch.float32)) * -10000.0
+ def mask_to_score(mask, dtype=torch.float32):
+ return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min
# [reader_beam_size, max_sequence_len, span_hidden_size * 2]
hidden_states = self.dense_intermediate(hidden_states)
@@ -889,7 +894,7 @@ def mask_to_score(mask):
# [reader_beam_size, num_candidates]
reader_logits = self.dense_output(candidate_hidden).squeeze(-1)
# [reader_beam_size, num_candidates]
- reader_logits += mask_to_score(candidate_mask)
+ reader_logits += mask_to_score(candidate_mask, dtype=reader_logits.dtype)
return reader_logits, candidate_starts, candidate_ends
@@ -1366,7 +1371,8 @@ def forward(
@add_start_docstrings(
- "The knowledge-augmented encoder of REALM outputting masked language model logits and marginal log-likelihood loss.",
+ "The knowledge-augmented encoder of REALM outputting masked language model logits and marginal log-likelihood"
+ " loss.",
REALM_START_DOCSTRING,
)
class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
@@ -1632,11 +1638,11 @@ def compute_correct_candidates(candidate_starts, candidate_ends, gold_starts, go
def marginal_log_loss(logits, is_correct):
"""Loss based on the negative marginal log-likelihood."""
- def mask_to_score(mask):
- return (1.0 - mask.type(torch.float32)) * -10000.0
+ def mask_to_score(mask, dtype=torch.float32):
+ return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min
# []
- log_numerator = torch.logsumexp(logits + mask_to_score(is_correct), dim=-1)
+ log_numerator = torch.logsumexp(logits + mask_to_score(is_correct, dtype=logits.dtype), dim=-1)
log_denominator = torch.logsumexp(logits, dim=-1)
return log_denominator - log_numerator
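The two `mask_to_score` changes above swap the hard-coded `-10000.0` additive penalty for `torch.finfo(dtype).min`, so the mask stays effectively minus infinity even when the logits are computed in half precision. A standalone sketch of the same idea:

# Standalone sketch of the dtype-aware additive mask used in the hunks above.
import torch


def mask_to_score(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Positions where mask == 0 receive the most negative value representable in `dtype`,
    # which is safe to add to fp16/bf16 logits before a softmax or logsumexp.
    return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min


logits = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
mask = torch.tensor([1, 0, 1])
print(logits + mask_to_score(mask, dtype=logits.dtype))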
diff --git a/src/transformers/models/realm/tokenization_realm.py b/src/transformers/models/realm/tokenization_realm.py
index 426b5d775cf9..63295826d462 100644
--- a/src/transformers/models/realm/tokenization_realm.py
+++ b/src/transformers/models/realm/tokenization_realm.py
@@ -30,10 +30,18 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
- "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
- "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
- "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt",
+ "google/realm-cc-news-pretrained-embedder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt"
+ ),
+ "google/realm-cc-news-pretrained-encoder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt"
+ ),
+ "google/realm-cc-news-pretrained-scorer": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt"
+ ),
+ "google/realm-cc-news-pretrained-openqa": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt"
+ ),
"google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
"google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
"google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
@@ -165,8 +173,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = RealmTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = RealmTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
diff --git a/src/transformers/models/realm/tokenization_realm_fast.py b/src/transformers/models/realm/tokenization_realm_fast.py
index 87580baa228b..f61fa8418ed2 100644
--- a/src/transformers/models/realm/tokenization_realm_fast.py
+++ b/src/transformers/models/realm/tokenization_realm_fast.py
@@ -31,24 +31,48 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
- "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
- "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
- "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt",
+ "google/realm-cc-news-pretrained-embedder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt"
+ ),
+ "google/realm-cc-news-pretrained-encoder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt"
+ ),
+ "google/realm-cc-news-pretrained-scorer": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt"
+ ),
+ "google/realm-cc-news-pretrained-openqa": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt"
+ ),
"google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
"google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
"google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
"google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt",
},
"tokenizer_file": {
- "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont",
- "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json",
- "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json",
- "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json",
- "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/tokenizer.json",
- "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/tokenizer.json",
- "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/tokenizer.json",
- "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/tokenizer.json",
+ "google/realm-cc-news-pretrained-embedder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont"
+ ),
+ "google/realm-cc-news-pretrained-encoder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json"
+ ),
+ "google/realm-cc-news-pretrained-scorer": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json"
+ ),
+ "google/realm-cc-news-pretrained-openqa": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json"
+ ),
+ "google/realm-orqa-nq-openqa": (
+ "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/tokenizer.json"
+ ),
+ "google/realm-orqa-nq-reader": (
+ "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/tokenizer.json"
+ ),
+ "google/realm-orqa-wq-openqa": (
+ "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/tokenizer.json"
+ ),
+ "google/realm-orqa-wq-reader": (
+ "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/reformer/__init__.py b/src/transformers/models/reformer/__init__.py
index 3c6130301b53..979074bcc728 100644
--- a/src/transformers/models/reformer/__init__.py
+++ b/src/transformers/models/reformer/__init__.py
@@ -18,20 +18,39 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_sentencepiece_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
-_import_structure = {
- "configuration_reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"],
-}
+_import_structure = {"configuration_reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_reformer"] = ["ReformerTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_reformer_fast"] = ["ReformerTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_reformer"] = [
"REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"ReformerAttention",
@@ -48,13 +67,28 @@
if TYPE_CHECKING:
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_reformer import ReformerTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_reformer_fast import ReformerTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_reformer import (
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ReformerAttention,
diff --git a/src/transformers/models/reformer/configuration_reformer.py b/src/transformers/models/reformer/configuration_reformer.py
index d481b3b13768..ea2a1abd0825 100755
--- a/src/transformers/models/reformer/configuration_reformer.py
+++ b/src/transformers/models/reformer/configuration_reformer.py
@@ -22,7 +22,9 @@
logger = logging.get_logger(__name__)
REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "google/reformer-crime-and-punishment": "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/config.json",
+ "google/reformer-crime-and-punishment": (
+ "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/config.json"
+ ),
"google/reformer-enwik8": "https://huggingface.co/google/reformer-enwik8/resolve/main/config.json",
}
diff --git a/src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py b/src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py
index 2e2e3f3a60dd..f25e166ef917 100755
--- a/src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py
+++ b/src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py
@@ -210,8 +210,10 @@ def convert_trax_checkpoint_to_pytorch(trax_model_pkl_path, config_file, pytorch
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained Reformer model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained Reformer model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py
index 089481f8542e..8430f3a62c0d 100755
--- a/src/transformers/models/reformer/modeling_reformer.py
+++ b/src/transformers/models/reformer/modeling_reformer.py
@@ -380,9 +380,10 @@ def forward(
# check if cache shall be used and that hidden states are already cached
if do_cached_attention:
- assert (
- sequence_length == 1
- ), f"At the moment, auto-regressive language generation is only possible one word at a time. Make sure that input sequence length {sequence_length} equals 1, when `past_buckets_states` is passed."
+ assert sequence_length == 1, (
+ "At the moment, auto-regressive language generation is only possible one word at a time. Make sure"
+ f" that input sequence length {sequence_length} equals 1, when `past_buckets_states` is passed."
+ )
past_buckets = past_buckets_states[0]
past_states = past_buckets_states[1]
@@ -505,9 +506,10 @@ def forward(
)
if self.chunk_length is None:
- assert (
- self.num_chunks_before == 0 and self.num_chunks_after == 0
- ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
+ assert self.num_chunks_before == 0 and self.num_chunks_after == 0, (
+ "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and"
+ " `config.num_chunks_before` are set to 0."
+ )
elif do_cached_attention and past_buckets is not None:
# use max sequence length
sorted_bucket_idx_per_hash = sorted_bucket_idx
@@ -577,7 +579,10 @@ def forward(
self.num_attention_heads,
sequence_length,
self.attention_head_size,
- ), "out_vectors have be of shape `[batch_size, config.num_attention_heads, sequence_length, config.attention_head_size]`."
+ ), (
+ "out_vectors have be of shape `[batch_size, config.num_attention_heads, sequence_length,"
+ " config.attention_head_size]`."
+ )
out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size)
@@ -891,7 +896,10 @@ def _get_relevant_hid_states_and_buckets(
self.num_attention_heads,
num_hashes,
sequence_length,
- ), f"bucket_idx should have shape {(batch_size, self.num_attention_heads, num_hashes, sequence_length)}, but has shape {bucket_idx.shape}."
+ ), (
+ f"bucket_idx should have shape {(batch_size, self.num_attention_heads, num_hashes, sequence_length)}, but"
+ f" has shape {bucket_idx.shape}."
+ )
# find indices of new bucket indices
relevant_bucket_idx = (bucket_idx == (bucket_idx.shape[-1] - 1)).nonzero()
@@ -925,12 +933,20 @@ def _get_relevant_hid_states_and_buckets(
assert (
relevant_hidden_states.shape[2]
== (self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length * num_hashes
- ), f"There should be {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length * num_hashes} `hidden_states`, there are {relevant_hidden_states.shape[2]} `hidden_states`."
+ ), (
+ "There should be"
+ f" {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length * num_hashes} `hidden_states`,"
+ f" there are {relevant_hidden_states.shape[2]} `hidden_states`."
+ )
assert (
relevant_bucket_idx_chunk.shape[-1]
== (self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length
- ), f"There should be {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length} `hidden_states`, there are {relevant_bucket_idx_chunk.shape[-1]} `bucket_idx`."
+ ), (
+ "There should be"
+ f" {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length} `hidden_states`, there are"
+ f" {relevant_bucket_idx_chunk.shape[-1]} `bucket_idx`."
+ )
return relevant_hidden_states, relevant_bucket_idx_chunk, query_buckets
@@ -1054,9 +1070,10 @@ def forward(
# check if cache shall be used and that hidden states are already cached
if use_cache and past_buckets_states[1] is not None:
- assert (
- past_buckets_states[0] is None
- ), "LocalSelfAttention should not make use of `buckets`. There seems to be an error when caching hidden_states_and_buckets."
+ assert past_buckets_states[0] is None, (
+ "LocalSelfAttention should not make use of `buckets`. There seems to be an error when caching"
+ " hidden_states_and_buckets."
+ )
key_value_hidden_states = self._retrieve_relevant_hidden_states(
past_buckets_states[1], self.chunk_length, self.num_chunks_before
)
@@ -1092,9 +1109,10 @@ def forward(
), f"last dim of query_key_vectors is {value_vectors.shape[-1]} but should be {self.attention_head_size}."
if self.chunk_length is None:
- assert (
- self.num_chunks_before == 0 and self.num_chunks_after == 0
- ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
+ assert self.num_chunks_before == 0 and self.num_chunks_after == 0, (
+ "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and"
+ " `config.num_chunks_before` are set to 0."
+ )
# normalize key vectors
key_vectors = key_vectors / torch.sqrt(
@@ -1514,9 +1532,10 @@ def backward_pass(
# Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0)
# This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py
- assert (
- self.training
- ), "If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the model into training mode."
+ assert self.training, (
+ "If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the"
+ " model into training mode."
+ )
with torch.enable_grad():
next_attn_output.requires_grad = True
@@ -1957,7 +1976,7 @@ class ReformerModelWithLMHeadOutput(ModelOutput):
@add_start_docstrings(
- "The bare Reformer Model transformer outputting raw hidden-states" "without any specific head on top.",
+ "The bare Reformer Model transformer outputting raw hidden-stateswithout any specific head on top.",
REFORMER_START_DOCSTRING,
)
class ReformerModel(ReformerPreTrainedModel):
@@ -2176,12 +2195,14 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
assert config.is_decoder, "If you want to use `ReformerModelWithLMHead` make sure that `is_decoder=True`."
- assert (
- "local" not in self.config.attn_layers or config.local_num_chunks_after == 0
- ), f"If causal mask is enabled, make sure that `config.local_num_chunks_after` is set to 0 and not {config.local_num_chunks_after}."
- assert (
- "lsh" not in self.config.attn_layers or config.lsh_num_chunks_after == 0
- ), f"If causal mask is enabled, make sure that `config.lsh_num_chunks_after` is set to 1 and not {config.lsh_num_chunks_after}."
+ assert "local" not in self.config.attn_layers or config.local_num_chunks_after == 0, (
+ "If causal mask is enabled, make sure that `config.local_num_chunks_after` is set to 0 and not"
+ f" {config.local_num_chunks_after}."
+ )
+ assert "lsh" not in self.config.attn_layers or config.lsh_num_chunks_after == 0, (
+ "If causal mask is enabled, make sure that `config.lsh_num_chunks_after` is set to 1 and not"
+ f" {config.lsh_num_chunks_after}."
+ )
self.reformer = ReformerModel(config)
self.lm_head = ReformerOnlyLMHead(config)
@@ -2296,9 +2317,10 @@ def _reorder_cache(self, past, beam_idx):
class ReformerForMaskedLM(ReformerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
- assert (
- not config.is_decoder
- ), "If you want to use `ReformerForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
+ assert not config.is_decoder, (
+ "If you want to use `ReformerForMaskedLM` make sure `config.is_decoder=False` for bi-directional"
+ " self-attention."
+ )
self.reformer = ReformerModel(config)
self.lm_head = ReformerOnlyLMHead(config)
diff --git a/src/transformers/models/reformer/tokenization_reformer.py b/src/transformers/models/reformer/tokenization_reformer.py
index 8c75dda15e70..d5d73f3e451f 100644
--- a/src/transformers/models/reformer/tokenization_reformer.py
+++ b/src/transformers/models/reformer/tokenization_reformer.py
@@ -34,7 +34,9 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "google/reformer-crime-and-punishment": "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
+ "google/reformer-crime-and-punishment": (
+ "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
+ )
}
}
diff --git a/src/transformers/models/reformer/tokenization_reformer_fast.py b/src/transformers/models/reformer/tokenization_reformer_fast.py
index e6a848379159..e9c6a61993d0 100644
--- a/src/transformers/models/reformer/tokenization_reformer_fast.py
+++ b/src/transformers/models/reformer/tokenization_reformer_fast.py
@@ -38,10 +38,14 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "google/reformer-crime-and-punishment": "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
+ "google/reformer-crime-and-punishment": (
+ "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
+ )
},
"tokenizer_file": {
- "google/reformer-crime-and-punishment": "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/tokenizer.json"
+ "google/reformer-crime-and-punishment": (
+ "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/tokenizer.json"
+ )
},
}
diff --git a/src/transformers/models/regnet/__init__.py b/src/transformers/models/regnet/__init__.py
index 185ead37b640..5399cb3f3be0 100644
--- a/src/transformers/models/regnet/__init__.py
+++ b/src/transformers/models/regnet/__init__.py
@@ -18,14 +18,17 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...file_utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
-_import_structure = {
- "configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"],
-}
+_import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_regnet"] = [
"REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"RegNetForImageClassification",
@@ -33,11 +36,29 @@
"RegNetPreTrainedModel",
]
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_regnet"] = [
+ "TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TFRegNetForImageClassification",
+ "TFRegNetModel",
+ "TFRegNetPreTrainedModel",
+ ]
+
if TYPE_CHECKING:
from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_regnet import (
REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
RegNetForImageClassification,
@@ -45,6 +66,19 @@
RegNetPreTrainedModel,
)
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_regnet import (
+ TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TFRegNetForImageClassification,
+ TFRegNetModel,
+ TFRegNetPreTrainedModel,
+ )
+
else:
import sys
diff --git a/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py b/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py
index 8024ef679201..a43967d0095d 100644
--- a/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py
+++ b/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py
@@ -277,7 +277,10 @@ def load_using_classy_vision(checkpoint_url: str) -> Tuple[Dict, Dict]:
"--model_name",
default=None,
type=str,
- help="The name of the model you wish to convert, it must be one of the supported regnet* architecture, currently: regnetx-*, regnety-*. If `None`, all of them will the converted.",
+ help=(
+ "The name of the model you wish to convert, it must be one of the supported regnet* architecture,"
+ " currently: regnetx-*, regnety-*. If `None`, all of them will the converted."
+ ),
)
parser.add_argument(
"--pytorch_dump_folder_path",
diff --git a/src/transformers/models/regnet/convert_regnet_to_pytorch.py b/src/transformers/models/regnet/convert_regnet_to_pytorch.py
index 96e4ab700ab5..9bb0ba0f0532 100644
--- a/src/transformers/models/regnet/convert_regnet_to_pytorch.py
+++ b/src/transformers/models/regnet/convert_regnet_to_pytorch.py
@@ -84,7 +84,8 @@ def __call__(self, x: Tensor):
if len(dest_traced) != len(src_traced) and self.raise_if_mismatch:
raise Exception(
- f"Numbers of operations are different. Source module has {len(src_traced)} operations while destination module has {len(dest_traced)}."
+ f"Numbers of operations are different. Source module has {len(src_traced)} operations while"
+ f" destination module has {len(dest_traced)}."
)
for dest_m, src_m in zip(dest_traced, src_traced):
@@ -431,7 +432,10 @@ def load_using_classy_vision(checkpoint_url: str, model_func: Callable[[], nn.Mo
"--model_name",
default=None,
type=str,
- help="The name of the model you wish to convert, it must be one of the supported regnet* architecture, currently: regnetx-*, regnety-*. If `None`, all of them will the converted.",
+ help=(
+ "The name of the model you wish to convert, it must be one of the supported regnet* architecture,"
+ " currently: regnetx-*, regnety-*. If `None`, all of them will the converted."
+ ),
)
parser.add_argument(
"--pytorch_dump_folder_path",
diff --git a/src/transformers/models/regnet/modeling_regnet.py b/src/transformers/models/regnet/modeling_regnet.py
index 0ebd05a25ce1..64b14dc54de8 100644
--- a/src/transformers/models/regnet/modeling_regnet.py
+++ b/src/transformers/models/regnet/modeling_regnet.py
@@ -45,7 +45,7 @@
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/regnet-y-040"
-_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/regnet-y-040",
@@ -93,14 +93,20 @@ def __init__(self, config: RegNetConfig):
self.embedder = RegNetConvLayer(
config.num_channels, config.embedding_size, kernel_size=3, stride=2, activation=config.hidden_act
)
+ self.num_channels = config.num_channels
- def forward(self, hidden_state):
- hidden_state = self.embedder(hidden_state)
+ def forward(self, pixel_values):
+ num_channels = pixel_values.shape[1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ hidden_state = self.embedder(pixel_values)
return hidden_state
# Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut with ResNet->RegNet
-class RegNetShortCut(nn.Sequential):
+class RegNetShortCut(nn.Module):
"""
RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
downsample the input using `stride=2`.
@@ -111,6 +117,11 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
self.normalization = nn.BatchNorm2d(out_channels)
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = self.convolution(input)
+ hidden_state = self.normalization(hidden_state)
+ return hidden_state
+
class RegNetSELayer(nn.Module):
"""
diff --git a/src/transformers/models/regnet/modeling_tf_regnet.py b/src/transformers/models/regnet/modeling_tf_regnet.py
new file mode 100644
index 000000000000..1d43d6eb7f8b
--- /dev/null
+++ b/src/transformers/models/regnet/modeling_tf_regnet.py
@@ -0,0 +1,523 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TensorFlow RegNet model."""
+
+from typing import Dict, Optional, Tuple, Union
+
+import tensorflow as tf
+
+from ...activations_tf import ACT2FN
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_tf_outputs import (
+ TFBaseModelOutputWithNoAttention,
+ TFBaseModelOutputWithPoolingAndNoAttention,
+ TFSequenceClassifierOutput,
+)
+from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, keras_serializable, unpack_inputs
+from ...tf_utils import shape_list
+from ...utils import logging
+from .configuration_regnet import RegNetConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "RegNetConfig"
+_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/regnet-y-040"
+_EXPECTED_OUTPUT_SHAPE = [1, 1088, 7, 7]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "facebook/regnet-y-040"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "facebook/regnet-y-040",
+ # See all regnet models at https://huggingface.co/models?filter=regnet
+]
+
+
+class TFRegNetConvLayer(tf.keras.layers.Layer):
+ def __init__(
+ self,
+ out_channels: int,
+ kernel_size: int = 3,
+ stride: int = 1,
+ groups: int = 1,
+ activation: Optional[str] = "relu",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ # The padding and conv has been verified in
+ # https://colab.research.google.com/gist/sayakpaul/854bc10eeaf21c9ee2119e0b9f3841a7/scratchpad.ipynb
+ self.padding = tf.keras.layers.ZeroPadding2D(padding=kernel_size // 2)
+ self.convolution = tf.keras.layers.Conv2D(
+ filters=out_channels,
+ kernel_size=kernel_size,
+ strides=stride,
+ padding="VALID",
+ groups=groups,
+ use_bias=False,
+ name="convolution",
+ )
+ self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
+ self.activation = ACT2FN[activation] if activation is not None else tf.identity
+
+ def call(self, hidden_state):
+ hidden_state = self.convolution(self.padding(hidden_state))
+ hidden_state = self.normalization(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class TFRegNetEmbeddings(tf.keras.layers.Layer):
+ """
+ RegNet Embeddings (stem) composed of a single aggressive convolution.
+ """
+
+ def __init__(self, config: RegNetConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.num_channels = config.num_channels
+ self.embedder = TFRegNetConvLayer(
+ out_channels=config.embedding_size,
+ kernel_size=3,
+ stride=2,
+ activation=config.hidden_act,
+ name="embedder",
+ )
+
+ def call(self, pixel_values):
+ num_channels = shape_list(pixel_values)[1]
+ if tf.executing_eagerly() and num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+
+ # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
+ # So change the input format from `NCHW` to `NHWC`.
+ # shape = (batch_size, in_height, in_width, in_channels=num_channels)
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+ hidden_state = self.embedder(pixel_values)
+ return hidden_state
+
+
+class TFRegNetShortCut(tf.keras.layers.Layer):
+ """
+ RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
+ downsample the input using `stride=2`.
+ """
+
+ def __init__(self, out_channels: int, stride: int = 2, **kwargs):
+ super().__init__(**kwargs)
+ self.convolution = tf.keras.layers.Conv2D(
+ filters=out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution"
+ )
+ self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
+
+ def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
+ return self.normalization(self.convolution(inputs), training=training)
+
+
+class TFRegNetSELayer(tf.keras.layers.Layer):
+ """
+ Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507).
+ """
+
+ def __init__(self, in_channels: int, reduced_channels: int, **kwargs):
+ super().__init__(**kwargs)
+ self.pooler = tf.keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler")
+ self.attention = [
+ tf.keras.layers.Conv2D(filters=reduced_channels, kernel_size=1, activation="relu", name="attention.0"),
+ tf.keras.layers.Conv2D(filters=in_channels, kernel_size=1, activation="sigmoid", name="attention.2"),
+ ]
+
+ def call(self, hidden_state):
+ # [batch_size, h, w, num_channels] -> [batch_size, 1, 1, num_channels]
+ pooled = self.pooler(hidden_state)
+ for layer_module in self.attention:
+ pooled = layer_module(pooled)
+ hidden_state = hidden_state * pooled
+ return hidden_state
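+
+ # Shape sketch (illustrative, hypothetical values): the SE block squeezes the spatial dimensions
+ # into a per-channel gate of shape [batch_size, 1, 1, num_channels] and rescales the input with it.
+ #
+ #   se = TFRegNetSELayer(in_channels=8, reduced_channels=2)
+ #   features = tf.random.uniform((1, 7, 7, 8))   # NHWC feature map
+ #   gated = se(features)                          # still (1, 7, 7, 8), channel-wise re-weighted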
+
+
+class TFRegNetXLayer(tf.keras.layers.Layer):
+ """
+ RegNet's layer, composed of a `1x1`, a grouped `3x3` and a `1x1` convolution, the same as a ResNet bottleneck layer with reduction = 1.
+ """
+
+ def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs):
+ super().__init__(**kwargs)
+ should_apply_shortcut = in_channels != out_channels or stride != 1
+ groups = max(1, out_channels // config.groups_width)
+ self.shortcut = (
+ TFRegNetShortCut(out_channels, stride=stride, name="shortcut")
+ if should_apply_shortcut
+ else tf.keras.layers.Activation("linear", name="shortcut")
+ )
+ # `self.layers` instead of `self.layer` because that is a reserved argument.
+ self.layers = [
+ TFRegNetConvLayer(out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"),
+ TFRegNetConvLayer(
+ out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1"
+ ),
+ TFRegNetConvLayer(out_channels, kernel_size=1, activation=None, name="layer.2"),
+ ]
+ self.activation = ACT2FN[config.hidden_act]
+
+ def call(self, hidden_state):
+ residual = hidden_state
+ for layer_module in self.layers:
+ hidden_state = layer_module(hidden_state)
+ residual = self.shortcut(residual)
+ hidden_state += residual
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class TFRegNetYLayer(tf.keras.layers.Layer):
+ """
+ RegNet's Y layer: an X layer with Squeeze and Excitation.
+ """
+
+ def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs):
+ super().__init__(**kwargs)
+ should_apply_shortcut = in_channels != out_channels or stride != 1
+ groups = max(1, out_channels // config.groups_width)
+ self.shortcut = (
+ TFRegNetShortCut(out_channels, stride=stride, name="shortcut")
+ if should_apply_shortcut
+ else tf.keras.layers.Activation("linear", name="shortcut")
+ )
+ self.layers = [
+ TFRegNetConvLayer(out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"),
+ TFRegNetConvLayer(
+ out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1"
+ ),
+ TFRegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4)), name="layer.2"),
+ TFRegNetConvLayer(out_channels, kernel_size=1, activation=None, name="layer.3"),
+ ]
+ self.activation = ACT2FN[config.hidden_act]
+
+ def call(self, hidden_state):
+ residual = hidden_state
+ for layer_module in self.layers:
+ hidden_state = layer_module(hidden_state)
+ residual = self.shortcut(residual)
+ hidden_state += residual
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class TFRegNetStage(tf.keras.layers.Layer):
+ """
+ A RegNet stage composed of stacked layers.
+ """
+
+ def __init__(
+ self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ layer = TFRegNetXLayer if config.layer_type == "x" else TFRegNetYLayer
+ self.layers = [
+ # downsampling is done in the first layer with stride of 2
+ layer(config, in_channels, out_channels, stride=stride, name="layers.0"),
+ *[layer(config, out_channels, out_channels, name=f"layers.{i+1}") for i in range(depth - 1)],
+ ]
+
+ def call(self, hidden_state):
+ for layer_module in self.layers:
+ hidden_state = layer_module(hidden_state)
+ return hidden_state
+
+
+class TFRegNetEncoder(tf.keras.layers.Layer):
+ def __init__(self, config: RegNetConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.stages = list()
+ # based on `downsample_in_first_stage`, the first layer of the first stage may or may not downsample the input
+ self.stages.append(
+ TFRegNetStage(
+ config,
+ config.embedding_size,
+ config.hidden_sizes[0],
+ stride=2 if config.downsample_in_first_stage else 1,
+ depth=config.depths[0],
+ name="stages.0",
+ )
+ )
+ in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
+ for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, config.depths[1:])):
+ self.stages.append(TFRegNetStage(config, in_channels, out_channels, depth=depth, name=f"stages.{i+1}"))
+
+ def call(
+ self, hidden_state: tf.Tensor, output_hidden_states: bool = False, return_dict: bool = True
+ ) -> TFBaseModelOutputWithNoAttention:
+ hidden_states = () if output_hidden_states else None
+
+ for stage_module in self.stages:
+ if output_hidden_states:
+ hidden_states = hidden_states + (hidden_state,)
+
+ hidden_state = stage_module(hidden_state)
+
+ if output_hidden_states:
+ hidden_states = hidden_states + (hidden_state,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_state, hidden_states] if v is not None)
+
+ return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
+
+
+@keras_serializable
+class TFRegNetMainLayer(tf.keras.layers.Layer):
+ config_class = RegNetConfig
+
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.embedder = TFRegNetEmbeddings(config, name="embedder")
+ self.encoder = TFRegNetEncoder(config, name="encoder")
+ self.pooler = tf.keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler")
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: tf.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPoolingAndNoAttention:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ embedding_output = self.embedder(pixel_values, training=training)
+
+ encoder_outputs = self.encoder(
+ embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ pooled_output = self.pooler(last_hidden_state)
+
+ # Change to NCHW output format to have uniformity in the modules
+ pooled_output = tf.transpose(pooled_output, perm=(0, 3, 1, 2))
+ last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
+
+ # Change the other hidden state outputs to NCHW as well
+ if output_hidden_states:
+ hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return TFBaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
+ )
+
+
+class TFRegNetPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = RegNetConfig
+ base_model_prefix = "regnet"
+ main_input_name = "pixel_values"
+
+ @property
+ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
+ """
+ Dummy inputs to build the network.
+
+ Returns:
+ `Dict[str, tf.Tensor]`: The dummy inputs.
+ """
+ VISION_DUMMY_INPUTS = tf.random.uniform(shape=(3, self.config.num_channels, 224, 224), dtype=tf.float32)
+ return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
+
+ @tf.function(
+ input_signature=[
+ {
+ "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
+ }
+ ]
+ )
+ def serving(self, inputs):
+ """
+ Method used for serving the model.
+
+ Args:
+ inputs (`Dict[str, tf.Tensor]`):
+ The input of the saved model as a dictionary of tensors.
+ """
+ output = self.call(inputs)
+ return self.serving_output(output)
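+
+ # Serving sketch (assumed workflow; paths and the SavedModel layout are illustrative): the signature
+ # above takes a float32 NCHW `pixel_values` tensor with dynamic shape, so an exported model could be
+ # queried along these lines:
+ #
+ #   model.save_pretrained("./regnet-tf", saved_model=True)   # also writes a TF SavedModel
+ #   loaded = tf.saved_model.load("./regnet-tf/saved_model/1")
+ #   loaded.signatures["serving_default"](pixel_values=tf.random.uniform((1, 3, 224, 224)))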
+
+
+REGNET_START_DOCSTRING = r"""
+ This model is a TensorFlow
+ [tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a
+ regular TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`RegNetConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+REGNET_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
+ [`AutoFeatureExtractor.__call__`] for details.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare RegNet model outputting raw features without any specific head on top.",
+ REGNET_START_DOCSTRING,
+)
+class TFRegNetModel(TFRegNetPreTrainedModel):
+ def __init__(self, config: RegNetConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.regnet = TFRegNetMainLayer(config, name="regnet")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPoolingAndNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def call(
+ self,
+ pixel_values: tf.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training=False,
+ ) -> Union[TFBaseModelOutputWithPoolingAndNoAttention, Tuple[tf.Tensor]]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.regnet(
+ pixel_values=pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ if not return_dict:
+ return (outputs[0],) + outputs[1:]
+
+ return TFBaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=outputs.last_hidden_state,
+ pooler_output=outputs.pooler_output,
+ hidden_states=outputs.hidden_states,
+ )
+
+ def serving_output(
+ self, output: TFBaseModelOutputWithPoolingAndNoAttention
+ ) -> TFBaseModelOutputWithPoolingAndNoAttention:
+ # hidden_states not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
+ return TFBaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=output.last_hidden_state,
+ pooler_output=output.pooler_output,
+ hidden_states=output.hidden_states,
+ )
+
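+# Usage sketch (assumed API, mirroring the docstring constants above; `image` stands for a PIL image):
+#
+#   from transformers import AutoFeatureExtractor, TFRegNetModel
+#
+#   feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/regnet-y-040")
+#   model = TFRegNetModel.from_pretrained("facebook/regnet-y-040")
+#   inputs = feature_extractor(images=image, return_tensors="tf")
+#   outputs = model(**inputs)
+#   outputs.last_hidden_state.shape   # (1, 1088, 7, 7), see _EXPECTED_OUTPUT_SHAPE
+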
+
+@add_start_docstrings(
+ """
+ RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """,
+ REGNET_START_DOCSTRING,
+)
+class TFRegNetForImageClassification(TFRegNetPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: RegNetConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.num_labels = config.num_labels
+ self.regnet = TFRegNetMainLayer(config, name="regnet")
+ # classification head
+ self.classifier = [
+ tf.keras.layers.Flatten(),
+ tf.keras.layers.Dense(config.num_labels, name="classifier.1") if config.num_labels > 0 else tf.identity,
+ ]
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=TFSequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def call(
+ self,
+ pixel_values: tf.Tensor = None,
+ labels: tf.Tensor = None,
+ output_hidden_states: bool = None,
+ return_dict: bool = None,
+ training=False,
+ ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.regnet(
+ pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
+ )
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+ flattened_output = self.classifier[0](pooled_output)
+ logits = self.classifier[1](flattened_output)
+
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
+
+ def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
+ # hidden_states not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
+ return TFSequenceClassifierOutput(logits=output.logits, hidden_states=output.hidden_states)
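+
+
+# Classification sketch (assumed API, mirroring the image-classification docstring constants above;
+# `image` stands for a PIL image):
+#
+#   from transformers import AutoFeatureExtractor, TFRegNetForImageClassification
+#
+#   feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/regnet-y-040")
+#   model = TFRegNetForImageClassification.from_pretrained("facebook/regnet-y-040")
+#   inputs = feature_extractor(images=image, return_tensors="tf")
+#   logits = model(**inputs).logits
+#   model.config.id2label[int(tf.math.argmax(logits, axis=-1)[0])]   # e.g. "tabby, tabby cat"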
diff --git a/src/transformers/models/rembert/__init__.py b/src/transformers/models/rembert/__init__.py
index fb5defeee5d0..10af6c4d27f3 100644
--- a/src/transformers/models/rembert/__init__.py
+++ b/src/transformers/models/rembert/__init__.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_tf_available,
@@ -27,17 +28,30 @@
)
-_import_structure = {
- "configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"],
-}
+_import_structure = {"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_rembert"] = ["RemBertTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_rembert_fast"] = ["RemBertTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_rembert"] = [
"REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"RemBertForCausalLM",
@@ -53,7 +67,12 @@
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_rembert"] = [
"TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRemBertForCausalLM",
@@ -71,13 +90,28 @@
if TYPE_CHECKING:
from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_rembert import RemBertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_rembert_fast import RemBertTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_rembert import (
REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
RemBertForCausalLM,
@@ -92,7 +126,12 @@
load_tf_weights_in_rembert,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_rembert import (
TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRemBertForCausalLM,
diff --git a/src/transformers/models/rembert/configuration_rembert.py b/src/transformers/models/rembert/configuration_rembert.py
index 589c40bdcb98..732d75c5cc2b 100644
--- a/src/transformers/models/rembert/configuration_rembert.py
+++ b/src/transformers/models/rembert/configuration_rembert.py
@@ -21,7 +21,7 @@
logger = logging.get_logger(__name__)
REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "rembert": "https://huggingface.co/google/rembert/resolve/main/config.json",
+ "google/rembert": "https://huggingface.co/google/rembert/resolve/main/config.json",
# See all RemBERT models at https://huggingface.co/models?filter=rembert
}
@@ -80,16 +80,17 @@ class RemBertConfig(PretrainedConfig):
Example:
```python
+ >>> from transformers import RemBertModel, RemBertConfig
- ```
+ >>> # Initializing a RemBERT rembert style configuration
+ >>> configuration = RemBertConfig()
- >>> from transformers import RemBertModel, RemBertConfig >>> # Initializing a RemBERT rembert style
- configuration >>> configuration = RemBertConfig()
+ >>> # Initializing a model from the rembert style configuration
+ >>> model = RemBertModel(configuration)
- >>> # Initializing a model from the rembert style configuration >>> model = RemBertModel(configuration)
-
- >>> # Accessing the model configuration >>> configuration = model.config
- """
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
model_type = "rembert"
def __init__(
diff --git a/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py b/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py
index 2a3c497d37a8..4c3d53e789de 100755
--- a/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py
@@ -51,8 +51,10 @@ def convert_rembert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_fil
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained RemBERT model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained RemBERT model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py
index c7b8da35a272..b6c20cb689d8 100755
--- a/src/transformers/models/rembert/modeling_rembert.py
+++ b/src/transformers/models/rembert/modeling_rembert.py
@@ -460,7 +460,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -785,7 +786,7 @@ class PreTrainedModel
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=BaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
@@ -938,7 +939,7 @@ def set_output_embeddings(self, new_embeddings):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1183,7 +1184,7 @@ def __init__(self, config):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1280,7 +1281,7 @@ def __init__(self, config):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1373,7 +1374,7 @@ def __init__(self, config):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1452,7 +1453,7 @@ def __init__(self, config):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
diff --git a/src/transformers/models/rembert/modeling_tf_rembert.py b/src/transformers/models/rembert/modeling_tf_rembert.py
index c039f2635037..2e25dafed483 100644
--- a/src/transformers/models/rembert/modeling_tf_rembert.py
+++ b/src/transformers/models/rembert/modeling_tf_rembert.py
@@ -414,8 +414,8 @@ def call(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
- "by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -938,7 +938,7 @@ def __init__(self, config: RemBertConfig, *inputs, **kwargs):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
@@ -1041,7 +1041,7 @@ def get_lm_head(self) -> tf.keras.layers.Layer:
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TFMaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1131,7 +1131,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
@unpack_inputs
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TFCausalLMOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
@@ -1262,7 +1262,7 @@ def __init__(self, config: RemBertConfig, *inputs, **kwargs):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TFSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1352,7 +1352,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TFMultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1471,7 +1471,7 @@ def __init__(self, config: RemBertConfig, *inputs, **kwargs):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TFTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1550,7 +1550,7 @@ def __init__(self, config: RemBertConfig, *inputs, **kwargs):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TFQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
diff --git a/src/transformers/models/resnet/__init__.py b/src/transformers/models/resnet/__init__.py
index 8a839228f872..f62c2999671d 100644
--- a/src/transformers/models/resnet/__init__.py
+++ b/src/transformers/models/resnet/__init__.py
@@ -18,14 +18,19 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
- "configuration_resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"],
+ "configuration_resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig", "ResNetOnnxConfig"]
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_resnet"] = [
"RESNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"ResNetForImageClassification",
@@ -33,11 +38,29 @@
"ResNetPreTrainedModel",
]
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_resnet"] = [
+ "TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TFResNetForImageClassification",
+ "TFResNetModel",
+ "TFResNetPreTrainedModel",
+ ]
+
if TYPE_CHECKING:
- from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig
+ from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig, ResNetOnnxConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_resnet import (
RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ResNetForImageClassification,
@@ -45,6 +68,19 @@
ResNetPreTrainedModel,
)
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_resnet import (
+ TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TFResNetForImageClassification,
+ TFResNetModel,
+ TFResNetPreTrainedModel,
+ )
+
else:
import sys
diff --git a/src/transformers/models/resnet/configuration_resnet.py b/src/transformers/models/resnet/configuration_resnet.py
index 8e5f6e656d1f..9bfc694bb144 100644
--- a/src/transformers/models/resnet/configuration_resnet.py
+++ b/src/transformers/models/resnet/configuration_resnet.py
@@ -14,7 +14,13 @@
# limitations under the License.
""" ResNet model configuration"""
+from collections import OrderedDict
+from typing import Mapping
+
+from packaging import version
+
from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
from ...utils import logging
@@ -89,3 +95,20 @@ def __init__(
self.layer_type = layer_type
self.hidden_act = hidden_act
self.downsample_in_first_stage = downsample_in_first_stage
+
+
+class ResNetOnnxConfig(OnnxConfig):
+
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("pixel_values", {0: "batch", 1: "sequence"}),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-3
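+
+
+# Export sketch (assumed workflow; the exact CLI invocation is an assumption and is not defined in
+# this file): with this config, a checkpoint can be exported to ONNX roughly as
+#
+#   python -m transformers.onnx --model=microsoft/resnet-50 onnx/
+#
+# which traces `pixel_values` with a dynamic "batch" axis and checks the exported graph against the
+# reference model within `atol_for_validation` (1e-3).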
diff --git a/src/transformers/models/resnet/convert_resnet_to_pytorch.py b/src/transformers/models/resnet/convert_resnet_to_pytorch.py
index 60973ecdec06..55a865ed5936 100644
--- a/src/transformers/models/resnet/convert_resnet_to_pytorch.py
+++ b/src/transformers/models/resnet/convert_resnet_to_pytorch.py
@@ -81,7 +81,8 @@ def __call__(self, x: Tensor):
if len(dest_traced) != len(src_traced):
raise Exception(
- f"Numbers of operations are different. Source module has {len(src_traced)} operations while destination module has {len(dest_traced)}."
+ f"Numbers of operations are different. Source module has {len(src_traced)} operations while"
+ f" destination module has {len(dest_traced)}."
)
for dest_m, src_m in zip(dest_traced, src_traced):
@@ -173,7 +174,10 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
"--model_name",
default=None,
type=str,
- help="The name of the model you wish to convert, it must be one of the supported resnet* architecture, currently: resnet18,26,34,50,101,152. If `None`, all of them will the converted.",
+ help=(
+ "The name of the model you wish to convert, it must be one of the supported resnet* architecture,"
+ " currently: resnet18,26,34,50,101,152. If `None`, all of them will the converted."
+ ),
)
parser.add_argument(
"--pytorch_dump_folder_path",
diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py
index f2f555d7f519..d8804d960443 100644
--- a/src/transformers/models/resnet/modeling_resnet.py
+++ b/src/transformers/models/resnet/modeling_resnet.py
@@ -52,7 +52,7 @@
]
-class ResNetConvLayer(nn.Sequential):
+class ResNetConvLayer(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
):
@@ -63,8 +63,14 @@ def __init__(
self.normalization = nn.BatchNorm2d(out_channels)
self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = self.convolution(input)
+ hidden_state = self.normalization(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
-class ResNetEmbeddings(nn.Sequential):
+class ResNetEmbeddings(nn.Module):
"""
ResNet Embeddings (stem) composed of a single aggressive convolution.
"""
@@ -75,9 +81,20 @@ def __init__(self, config: ResNetConfig):
config.num_channels, config.embedding_size, kernel_size=7, stride=2, activation=config.hidden_act
)
self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.num_channels = config.num_channels
+
+ def forward(self, pixel_values: Tensor) -> Tensor:
+ num_channels = pixel_values.shape[1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ embedding = self.embedder(pixel_values)
+ embedding = self.pooler(embedding)
+ return embedding
-class ResNetShortCut(nn.Sequential):
+class ResNetShortCut(nn.Module):
"""
ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
downsample the input using `stride=2`.
@@ -88,10 +105,15 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
self.normalization = nn.BatchNorm2d(out_channels)
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = self.convolution(input)
+ hidden_state = self.normalization(hidden_state)
+ return hidden_state
+
class ResNetBasicLayer(nn.Module):
"""
- A classic ResNet's residual layer composed by a two `3x3` convolutions.
+ A classic ResNet's residual layer composed by two `3x3` convolutions.
"""
def __init__(self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu"):
@@ -117,10 +139,10 @@ def forward(self, hidden_state):
class ResNetBottleNeckLayer(nn.Module):
"""
- A classic ResNet's bottleneck layer composed by a three `3x3` convolutions.
+ A classic ResNet's bottleneck layer composed by three `3x3` convolutions.
The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
- convolution faster. The last `1x1` convolution remap the reduced features to `out_channels`.
+ convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`.
"""
def __init__(
@@ -148,7 +170,7 @@ def forward(self, hidden_state):
return hidden_state
-class ResNetStage(nn.Sequential):
+class ResNetStage(nn.Module):
"""
A ResNet stage composed by stacked layers.
"""
@@ -171,6 +193,12 @@ def __init__(
*[layer(out_channels, out_channels, activation=config.hidden_act) for _ in range(depth - 1)],
)
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = input
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
+
class ResNetEncoder(nn.Module):
def __init__(self, config: ResNetConfig):
diff --git a/src/transformers/models/resnet/modeling_tf_resnet.py b/src/transformers/models/resnet/modeling_tf_resnet.py
new file mode 100644
index 000000000000..bed053ae404f
--- /dev/null
+++ b/src/transformers/models/resnet/modeling_tf_resnet.py
@@ -0,0 +1,501 @@
+# coding=utf-8
+# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TensorFlow ResNet model."""
+
+from typing import Dict, Optional, Tuple, Union
+
+import tensorflow as tf
+
+from ...activations_tf import ACT2FN
+from ...modeling_tf_outputs import (
+ TFBaseModelOutputWithNoAttention,
+ TFBaseModelOutputWithPoolingAndNoAttention,
+ TFImageClassifierOutputWithNoAttention,
+)
+from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, keras_serializable, unpack_inputs
+from ...tf_utils import shape_list
+from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_resnet import ResNetConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "ResNetConfig"
+_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "microsoft/resnet-50"
+_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "microsoft/resnet-50"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tiger cat"
+
+TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "microsoft/resnet-50",
+ # See all resnet models at https://huggingface.co/models?filter=resnet
+]
+
+
+class TFResNetConvLayer(tf.keras.layers.Layer):
+ def __init__(
+ self, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu", **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.pad_value = kernel_size // 2
+ self.conv = tf.keras.layers.Conv2D(
+ out_channels, kernel_size=kernel_size, strides=stride, padding="valid", use_bias=False, name="convolution"
+ )
+ # Use same default momentum and epsilon as PyTorch equivalent
+ self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
+ self.activation = ACT2FN[activation] if activation is not None else tf.keras.layers.Activation("linear")
+
+ def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor:
+ # Pad to match that done in the PyTorch Conv2D model
+ height_pad = width_pad = (self.pad_value, self.pad_value)
+ hidden_state = tf.pad(hidden_state, [(0, 0), height_pad, width_pad, (0, 0)])
+ hidden_state = self.conv(hidden_state)
+ return hidden_state
+
+ def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_state = self.convolution(hidden_state)
+ hidden_state = self.normalization(hidden_state, training=training)
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
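+
+ # Padding sketch (illustrative): with the default kernel_size=3, pad_value=1 pixel is added on each
+ # side before the VALID convolution, matching PyTorch's `padding=kernel_size // 2`; e.g. a
+ # (1, 56, 56, 64) NHWC input stays 56x56 spatially for stride=1 and becomes 28x28 for stride=2.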
+
+
+class TFResNetEmbeddings(tf.keras.layers.Layer):
+ """
+ ResNet Embeddings (stem) composed of a single aggressive convolution.
+ """
+
+ def __init__(self, config: ResNetConfig, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.embedder = TFResNetConvLayer(
+ config.embedding_size,
+ kernel_size=7,
+ stride=2,
+ activation=config.hidden_act,
+ name="embedder",
+ )
+ self.pooler = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="valid", name="pooler")
+ self.num_channels = config.num_channels
+
+ def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
+ _, _, _, num_channels = shape_list(pixel_values)
+ if tf.executing_eagerly() and num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ hidden_state = pixel_values
+ hidden_state = self.embedder(hidden_state)
+ hidden_state = tf.pad(hidden_state, [[0, 0], [1, 1], [1, 1], [0, 0]])
+ hidden_state = self.pooler(hidden_state)
+ return hidden_state
+
+
+class TFResNetShortCut(tf.keras.layers.Layer):
+ """
+ ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
+ downsample the input using `stride=2`.
+ """
+
+ def __init__(self, out_channels: int, stride: int = 2, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.convolution = tf.keras.layers.Conv2D(
+ out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution"
+ )
+ # Use same default momentum and epsilon as PyTorch equivalent
+ self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
+
+ def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_state = x
+ hidden_state = self.convolution(hidden_state)
+ hidden_state = self.normalization(hidden_state, training=training)
+ return hidden_state
+
+
+class TFResNetBasicLayer(tf.keras.layers.Layer):
+ """
+ A classic ResNet's residual layer composed of two `3x3` convolutions.
+ """
+
+ def __init__(
+ self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu", **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ should_apply_shortcut = in_channels != out_channels or stride != 1
+ self.conv1 = TFResNetConvLayer(out_channels, stride=stride, name="layer.0")
+ self.conv2 = TFResNetConvLayer(out_channels, activation=None, name="layer.1")
+ self.shortcut = (
+ TFResNetShortCut(out_channels, stride=stride, name="shortcut")
+ if should_apply_shortcut
+ else tf.keras.layers.Activation("linear", name="shortcut")
+ )
+ self.activation = ACT2FN[activation]
+
+ def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
+ residual = hidden_state
+ hidden_state = self.conv1(hidden_state, training=training)
+ hidden_state = self.conv2(hidden_state, training=training)
+ residual = self.shortcut(residual, training=training)
+ hidden_state += residual
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class TFResNetBottleNeckLayer(tf.keras.layers.Layer):
+ """
+ A classic ResNet's bottleneck layer composed of three convolutions: a `1x1`, a `3x3` and a `1x1`.
+
+ The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
+ convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ stride: int = 1,
+ activation: str = "relu",
+ reduction: int = 4,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ should_apply_shortcut = in_channels != out_channels or stride != 1
+ reduces_channels = out_channels // reduction
+ self.conv0 = TFResNetConvLayer(reduces_channels, kernel_size=1, name="layer.0")
+ self.conv1 = TFResNetConvLayer(reduces_channels, stride=stride, name="layer.1")
+ self.conv2 = TFResNetConvLayer(out_channels, kernel_size=1, activation=None, name="layer.2")
+ self.shortcut = (
+ TFResNetShortCut(out_channels, stride=stride, name="shortcut")
+ if should_apply_shortcut
+ else tf.keras.layers.Activation("linear", name="shortcut")
+ )
+ self.activation = ACT2FN[activation]
+
+ def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
+ residual = hidden_state
+ hidden_state = self.conv0(hidden_state, training=training)
+ hidden_state = self.conv1(hidden_state, training=training)
+ hidden_state = self.conv2(hidden_state, training=training)
+ residual = self.shortcut(residual, training=training)
+ hidden_state += residual
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
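+
+ # Worked example (illustrative): with in_channels=256, out_channels=512 and the default reduction=4,
+ # the block runs 1x1 -> 128 channels, 3x3 (optionally strided) -> 128 channels, 1x1 -> 512 channels,
+ # and the shortcut projects 256 -> 512 so the residual addition above stays shape-compatible.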
+
+
+class TFResNetStage(tf.keras.layers.Layer):
+ """
+ A ResNet stage composed of stacked layers.
+ """
+
+ def __init__(
+ self, config: ResNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+
+ layer = TFResNetBottleNeckLayer if config.layer_type == "bottleneck" else TFResNetBasicLayer
+
+ layers = [layer(in_channels, out_channels, stride=stride, activation=config.hidden_act, name="layers.0")]
+ layers += [
+ layer(out_channels, out_channels, activation=config.hidden_act, name=f"layers.{i + 1}")
+ for i in range(depth - 1)
+ ]
+ self.stage_layers = layers
+
+ def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
+ for layer in self.stage_layers:
+ hidden_state = layer(hidden_state, training=training)
+ return hidden_state
+
+
+class TFResNetEncoder(tf.keras.layers.Layer):
+ def __init__(self, config: ResNetConfig, **kwargs) -> None:
+ super().__init__(**kwargs)
+ # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input
+ self.stages = [
+ TFResNetStage(
+ config,
+ config.embedding_size,
+ config.hidden_sizes[0],
+ stride=2 if config.downsample_in_first_stage else 1,
+ depth=config.depths[0],
+ name="stages.0",
+ )
+ ]
+ for i, (in_channels, out_channels, depth) in enumerate(
+ zip(config.hidden_sizes, config.hidden_sizes[1:], config.depths[1:])
+ ):
+ self.stages.append(TFResNetStage(config, in_channels, out_channels, depth=depth, name=f"stages.{i + 1}"))
+
+ def call(
+ self,
+ hidden_state: tf.Tensor,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithNoAttention:
+ hidden_states = () if output_hidden_states else None
+
+ for stage_module in self.stages:
+ if output_hidden_states:
+ hidden_states = hidden_states + (hidden_state,)
+
+ hidden_state = stage_module(hidden_state, training=training)
+
+ if output_hidden_states:
+ hidden_states = hidden_states + (hidden_state,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_state, hidden_states] if v is not None)
+
+ return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
+
+
+class TFResNetPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = ResNetConfig
+ base_model_prefix = "resnet"
+ main_input_name = "pixel_values"
+
+ @property
+ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
+ """
+ Dummy inputs to build the network.
+
+ Returns:
+ `Dict[str, tf.Tensor]`: The dummy inputs.
+ """
+ VISION_DUMMY_INPUTS = tf.random.uniform(shape=(3, self.config.num_channels, 224, 224), dtype=tf.float32)
+ return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
+
+ @tf.function(
+ input_signature=[
+ {
+ "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
+ }
+ ]
+ )
+ def serving(self, inputs):
+ output = self.call(inputs)
+ return self.serving_output(output)
+
+
+RESNET_START_DOCSTRING = r"""
+ This model is a TensorFlow
+ [tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a
+ regular TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`ResNetConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+RESNET_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
+ [`AutoFeatureExtractor.__call__`] for details.
+
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@keras_serializable
+class TFResNetMainLayer(tf.keras.layers.Layer):
+ config_class = ResNetConfig
+
+ def __init__(self, config: ResNetConfig, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.config = config
+ self.embedder = TFResNetEmbeddings(config, name="embedder")
+ self.encoder = TFResNetEncoder(config, name="encoder")
+ self.pooler = tf.keras.layers.GlobalAveragePooling2D(keepdims=True)
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: tf.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPoolingAndNoAttention]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # TF 2.0 image layers can't use NCHW format when running on CPU.
+ # We transpose to NHWC format and then transpose back after the full forward pass.
+ # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
+ pixel_values = tf.transpose(pixel_values, perm=[0, 2, 3, 1])
+ embedding_output = self.embedder(pixel_values, training=training)
+
+ encoder_outputs = self.encoder(
+ embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
+ )
+
+ last_hidden_state = encoder_outputs[0]
+
+ pooled_output = self.pooler(last_hidden_state)
+
+ # Transpose all the outputs to the NCHW format
+ # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width)
+ last_hidden_state = tf.transpose(last_hidden_state, (0, 3, 1, 2))
+ pooled_output = tf.transpose(pooled_output, (0, 3, 1, 2))
+ hidden_states = ()
+ for hidden_state in encoder_outputs[1:]:
+ hidden_states = hidden_states + tuple(tf.transpose(h, (0, 3, 1, 2)) for h in hidden_state)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + hidden_states
+
+ hidden_states = hidden_states if output_hidden_states else None
+
+ return TFBaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=hidden_states,
+ )
+
+
+@add_start_docstrings(
+ "The bare ResNet model outputting raw features without any specific head on top.",
+ RESNET_START_DOCSTRING,
+)
+class TFResNetModel(TFResNetPreTrainedModel):
+ def __init__(self, config: ResNetConfig, **kwargs) -> None:
+ super().__init__(config, **kwargs)
+ self.resnet = TFResNetMainLayer(config=config, name="resnet")
+
+ @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPoolingAndNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: tf.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPoolingAndNoAttention]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ resnet_outputs = self.resnet(
+ pixel_values=pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ return resnet_outputs
+
+ def serving_output(
+ self, output: TFBaseModelOutputWithPoolingAndNoAttention
+ ) -> TFBaseModelOutputWithPoolingAndNoAttention:
+ # hidden_states not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
+ return TFBaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=output.last_hidden_state,
+ pooler_output=output.pooler_output,
+ hidden_states=output.hidden_states,
+ )
+
+
+@add_start_docstrings(
+ """
+ ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """,
+ RESNET_START_DOCSTRING,
+)
+class TFResNetForImageClassification(TFResNetPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: ResNetConfig, **kwargs) -> None:
+ super().__init__(config, **kwargs)
+ self.num_labels = config.num_labels
+ self.resnet = TFResNetMainLayer(config, name="resnet")
+ # classification head
+ self.classifier_layer = (
+ tf.keras.layers.Dense(config.num_labels, name="classifier.1")
+ if config.num_labels > 0
+ else tf.keras.layers.Activation("linear", name="classifier.1")
+ )
+
+ def classifier(self, x: tf.Tensor) -> tf.Tensor:
+ x = tf.keras.layers.Flatten()(x)
+ logits = self.classifier_layer(x)
+ return logits
+
+ @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=TFImageClassifierOutputWithNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: tf.Tensor = None,
+ labels: tf.Tensor = None,
+ output_hidden_states: bool = None,
+ return_dict: bool = None,
+ training: bool = False,
+ ) -> Union[Tuple[tf.Tensor], TFImageClassifierOutputWithNoAttention]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.resnet(
+ pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
+ )
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+ logits = self.classifier(pooled_output)
+
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return (loss,) + output if loss is not None else output
+
+ return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
+
+ def serving_output(self, output: TFImageClassifierOutputWithNoAttention) -> TFImageClassifierOutputWithNoAttention:
+ # hidden_states not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
+ return TFImageClassifierOutputWithNoAttention(logits=output.logits, hidden_states=output.hidden_states)
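+
+
+# Training-style sketch (assumed API; the label id is hypothetical and `image` stands for a PIL
+# image): labels can be passed directly so the classification loss is computed inside the model.
+#
+#   from transformers import AutoFeatureExtractor, TFResNetForImageClassification
+#
+#   feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
+#   model = TFResNetForImageClassification.from_pretrained("microsoft/resnet-50")
+#   inputs = feature_extractor(images=image, return_tensors="tf")
+#   outputs = model(**inputs, labels=tf.constant([282]))   # 282: assumed ImageNet-1k id for "tiger cat"
+#   outputs.loss, outputs.logits.shape                     # per-example loss and (1, 1000) logits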
diff --git a/src/transformers/models/retribert/__init__.py b/src/transformers/models/retribert/__init__.py
index e4d383780b66..34cfadfe1a87 100644
--- a/src/transformers/models/retribert/__init__.py
+++ b/src/transformers/models/retribert/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
@@ -26,10 +26,20 @@
"tokenization_retribert": ["RetriBertTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_retribert_fast"] = ["RetriBertTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_retribert"] = [
"RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"RetriBertModel",
@@ -41,10 +51,20 @@
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .tokenization_retribert import RetriBertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_retribert_fast import RetriBertTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_retribert import (
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
RetriBertModel,
diff --git a/src/transformers/models/retribert/configuration_retribert.py b/src/transformers/models/retribert/configuration_retribert.py
index 1e4feb2a6909..23172cf40ec7 100644
--- a/src/transformers/models/retribert/configuration_retribert.py
+++ b/src/transformers/models/retribert/configuration_retribert.py
@@ -22,7 +22,9 @@
# TODO: upload to AWS
RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "yjernite/retribert-base-uncased": "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/config.json",
+ "yjernite/retribert-base-uncased": (
+ "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/retribert/modeling_retribert.py b/src/transformers/models/retribert/modeling_retribert.py
index 5a12c962e292..03ffc92ba659 100644
--- a/src/transformers/models/retribert/modeling_retribert.py
+++ b/src/transformers/models/retribert/modeling_retribert.py
@@ -201,7 +201,7 @@ def forward(
Indices of input sequence tokens in the vocabulary for the documents in a batch.
attention_mask_doc (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on documents padding token indices.
- checkpoint_batch_size (`int`, *optional*, defaults to ```-1`):
+ checkpoint_batch_size (`int`, *optional*, defaults to `-1`):
If greater than 0, uses gradient checkpointing to only compute sequence representation on
`checkpoint_batch_size` examples at a time on the GPU. All query representations are still compared to
all document representations in the batch.
diff --git a/src/transformers/models/retribert/tokenization_retribert.py b/src/transformers/models/retribert/tokenization_retribert.py
index be9a40913fab..b61c0634406a 100644
--- a/src/transformers/models/retribert/tokenization_retribert.py
+++ b/src/transformers/models/retribert/tokenization_retribert.py
@@ -24,7 +24,9 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "yjernite/retribert-base-uncased": "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/vocab.txt",
+ "yjernite/retribert-base-uncased": (
+ "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/vocab.txt"
+ ),
}
}
diff --git a/src/transformers/models/retribert/tokenization_retribert_fast.py b/src/transformers/models/retribert/tokenization_retribert_fast.py
index 43cc3837214b..3451d1224a7a 100644
--- a/src/transformers/models/retribert/tokenization_retribert_fast.py
+++ b/src/transformers/models/retribert/tokenization_retribert_fast.py
@@ -25,10 +25,14 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "yjernite/retribert-base-uncased": "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/vocab.txt",
+ "yjernite/retribert-base-uncased": (
+ "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "yjernite/retribert-base-uncased": "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/tokenizer.json",
+ "yjernite/retribert-base-uncased": (
+ "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/roberta/__init__.py b/src/transformers/models/roberta/__init__.py
index 739029eac9b7..2429ba113e8a 100644
--- a/src/transformers/models/roberta/__init__.py
+++ b/src/transformers/models/roberta/__init__.py
@@ -18,7 +18,14 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -26,10 +33,20 @@
"tokenization_roberta": ["RobertaTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_roberta_fast"] = ["RobertaTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_roberta"] = [
"ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"RobertaForCausalLM",
@@ -42,7 +59,12 @@
"RobertaPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_roberta"] = [
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRobertaForCausalLM",
@@ -56,7 +78,12 @@
"TFRobertaPreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_roberta"] = [
"FlaxRobertaForCausalLM",
"FlaxRobertaForMaskedLM",
@@ -73,10 +100,20 @@
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaOnnxConfig
from .tokenization_roberta import RobertaTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_roberta_fast import RobertaTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_roberta import (
ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
RobertaForCausalLM,
@@ -89,7 +126,12 @@
RobertaPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForCausalLM,
@@ -103,7 +145,12 @@
TFRobertaPreTrainedModel,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_roberta import (
FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM,
diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py
index 4a34fa77bc78..ddd6359b36be 100644
--- a/src/transformers/models/roberta/modeling_flax_roberta.py
+++ b/src/transformers/models/roberta/modeling_flax_roberta.py
@@ -21,6 +21,7 @@
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
@@ -47,6 +48,8 @@
_CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
+remat = nn_partitioning.remat
+
def create_position_ids_from_input_ids(input_ids, padding_idx):
"""
@@ -183,8 +186,8 @@ def setup(self):
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
- "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
- : {self.config.num_attention_heads}"
+ f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
+ f" : {self.config.num_attention_heads}"
)
self.query = nn.Dense(
@@ -511,11 +514,20 @@ def __call__(
class FlaxRobertaLayerCollection(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ gradient_checkpointing: bool = False
def setup(self):
- self.layers = [
- FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
- ]
+ if self.gradient_checkpointing:
+ FlaxRobertaCheckpointLayer = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7))
+ self.layers = [
+ FlaxRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+ for i in range(self.config.num_hidden_layers)
+ ]
+ else:
+ self.layers = [
+ FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype)
+ for i in range(self.config.num_hidden_layers)
+ ]
def __call__(
self,
@@ -538,8 +550,8 @@ def __call__(
if head_mask is not None:
if head_mask.shape[0] != (len(self.layers)):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for \
- {head_mask.shape[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
+ f" {head_mask.shape[0]}."
)
for i, layer in enumerate(self.layers):
@@ -549,12 +561,12 @@ def __call__(
layer_outputs = layer(
hidden_states,
attention_mask,
- layer_head_mask=head_mask[i] if head_mask is not None else None,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- init_cache=init_cache,
- deterministic=deterministic,
- output_attentions=output_attentions,
+ head_mask[i] if head_mask is not None else None,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ init_cache,
+ deterministic,
+ output_attentions,
)
hidden_states = layer_outputs[0]
@@ -585,9 +597,14 @@ def __call__(
class FlaxRobertaEncoder(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ gradient_checkpointing: bool = False
def setup(self):
- self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
+ self.layer = FlaxRobertaLayerCollection(
+ self.config,
+ dtype=self.dtype,
+ gradient_checkpointing=self.gradient_checkpointing,
+ )
def __call__(
self,
@@ -719,11 +736,20 @@ def __init__(
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
+ gradient_checkpointing: bool = False,
**kwargs
):
- module = self.module_class(config=config, dtype=dtype, **kwargs)
+ module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
+ def enable_gradient_checkpointing(self):
+ self._module = self.module_class(
+ config=self.config,
+ dtype=self.dtype,
+ gradient_checkpointing=True,
+ )
+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
@@ -889,10 +915,15 @@ class FlaxRobertaModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
+ gradient_checkpointing: bool = False
def setup(self):
self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype)
- self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
+ self.encoder = FlaxRobertaEncoder(
+ self.config,
+ dtype=self.dtype,
+ gradient_checkpointing=self.gradient_checkpointing,
+ )
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
def __call__(
@@ -967,9 +998,15 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
class FlaxRobertaForMaskedLMModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
def setup(self):
- self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+ self.roberta = FlaxRobertaModule(
+ config=self.config,
+ add_pooling_layer=False,
+ dtype=self.dtype,
+ gradient_checkpointing=self.gradient_checkpointing,
+ )
self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
def __call__(
@@ -1034,9 +1071,15 @@ class FlaxRobertaForMaskedLM(FlaxRobertaPreTrainedModel):
class FlaxRobertaForSequenceClassificationModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
def setup(self):
- self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+ self.roberta = FlaxRobertaModule(
+ config=self.config,
+ dtype=self.dtype,
+ add_pooling_layer=False,
+ gradient_checkpointing=self.gradient_checkpointing,
+ )
self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype)
def __call__(
@@ -1101,9 +1144,14 @@ class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel):
class FlaxRobertaForMultipleChoiceModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
def setup(self):
- self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype)
+ self.roberta = FlaxRobertaModule(
+ config=self.config,
+ dtype=self.dtype,
+ gradient_checkpointing=self.gradient_checkpointing,
+ )
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(1, dtype=self.dtype)
@@ -1181,9 +1229,15 @@ class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel):
class FlaxRobertaForTokenClassificationModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
def setup(self):
- self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+ self.roberta = FlaxRobertaModule(
+ config=self.config,
+ dtype=self.dtype,
+ add_pooling_layer=False,
+ gradient_checkpointing=self.gradient_checkpointing,
+ )
classifier_dropout = (
self.config.classifier_dropout
if self.config.classifier_dropout is not None
@@ -1255,9 +1309,15 @@ class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel):
class FlaxRobertaForQuestionAnsweringModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
def setup(self):
- self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+ self.roberta = FlaxRobertaModule(
+ config=self.config,
+ dtype=self.dtype,
+ add_pooling_layer=False,
+ gradient_checkpointing=self.gradient_checkpointing,
+ )
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
@@ -1326,9 +1386,15 @@ class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel):
class FlaxRobertaForCausalLMModule(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
def setup(self):
- self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+ self.roberta = FlaxRobertaModule(
+ config=self.config,
+ add_pooling_layer=False,
+ dtype=self.dtype,
+ gradient_checkpointing=self.gradient_checkpointing,
+ )
self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
def __call__(
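
The Flax RoBERTa changes thread a `gradient_checkpointing` flag down to `FlaxRobertaLayerCollection`, where `flax.linen.partitioning.remat` wraps the layer class so activations are recomputed during the backward pass. The layer call also switches from keyword to positional arguments because the indices in `static_argnums` refer to positional call arguments, counted after `self`, as in the `(5, 6, 7)` used above for the boolean flags. A rough sketch with a toy block rather than the real `FlaxRobertaLayer`:

```python
# Sketch of enabling gradient checkpointing with flax.linen remat (toy module,
# not the actual FlaxRobertaLayer).
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning

remat = nn_partitioning.remat


class ToyBlock(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        x = nn.Dense(self.features)(x)
        x = nn.Dropout(rate=0.1)(x, deterministic=deterministic)
        return nn.relu(x)


class ToyEncoder(nn.Module):
    features: int
    num_layers: int
    gradient_checkpointing: bool = False

    def setup(self):
        if self.gradient_checkpointing:
            # `deterministic` is call argument index 1 (counting after `self`),
            # following the same bookkeeping as static_argnums=(5, 6, 7) above.
            block_cls = remat(ToyBlock, static_argnums=(1,))
        else:
            block_cls = ToyBlock
        self.layers = [block_cls(self.features, name=str(i)) for i in range(self.num_layers)]

    def __call__(self, x, deterministic: bool = True):
        for layer in self.layers:
            # Static arguments must be passed positionally, hence no keyword here.
            x = layer(x, deterministic)
        return x


# Usage sketch:
# import jax, jax.numpy as jnp
# model = ToyEncoder(features=8, num_layers=2, gradient_checkpointing=True)
# params = model.init(jax.random.PRNGKey(0), jnp.ones((2, 8)), True)
```

As the diff shows, `enable_gradient_checkpointing()` on the pretrained model simply rebuilds the module with `gradient_checkpointing=True`.
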
diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py
index 3b5f6a9a6ba3..46add0be5001 100644
--- a/src/transformers/models/roberta/modeling_roberta.py
+++ b/src/transformers/models/roberta/modeling_roberta.py
@@ -20,7 +20,6 @@
import torch
import torch.utils.checkpoint
-from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -36,7 +35,12 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ is_torch_greater_than_1_6,
+ prune_linear_layer,
+)
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -83,7 +87,7 @@ def __init__(self, config):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
- if version.parse(torch.__version__) > version.parse("1.6.0"):
+ if is_torch_greater_than_1_6:
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long),
@@ -426,7 +430,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
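
In `modeling_roberta.py` the per-model `version.parse(torch.__version__)` comparison is replaced by a flag imported from `pytorch_utils`, so the version string is parsed once at import time and every model just tests a boolean. A sketch of how such a flag can be defined (illustrative only; the actual contents of `transformers.pytorch_utils` are not shown in this diff):

```python
# Illustrative module-level version flag (not the verbatim pytorch_utils source).
import torch
from packaging import version

# Strip any local/dev suffix (e.g. "1.10.0+cu113") before comparing.
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)

is_torch_greater_than_1_6 = parsed_torch_version_base > version.parse("1.6")

# Consumers then just test the boolean, as in the diff:
if is_torch_greater_than_1_6:
    pass  # e.g. register the persistent token_type_ids buffer
```
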
diff --git a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py
index 7c39b7334a46..a320664bcea5 100644
--- a/src/transformers/models/roberta/modeling_tf_roberta.py
+++ b/src/transformers/models/roberta/modeling_tf_roberta.py
@@ -463,8 +463,8 @@ def call(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
- "by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
diff --git a/src/transformers/models/roberta/tokenization_roberta.py b/src/transformers/models/roberta/tokenization_roberta.py
index 0d87615c1569..10b28125e92b 100644
--- a/src/transformers/models/roberta/tokenization_roberta.py
+++ b/src/transformers/models/roberta/tokenization_roberta.py
@@ -39,7 +39,9 @@
"roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/vocab.json",
"distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/vocab.json",
"roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/vocab.json",
- "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/vocab.json",
+ "roberta-large-openai-detector": (
+ "https://huggingface.co/roberta-large-openai-detector/resolve/main/vocab.json"
+ ),
},
"merges_file": {
"roberta-base": "https://huggingface.co/roberta-base/resolve/main/merges.txt",
@@ -47,7 +49,9 @@
"roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/merges.txt",
"distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/merges.txt",
"roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/merges.txt",
- "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/merges.txt",
+ "roberta-large-openai-detector": (
+ "https://huggingface.co/roberta-large-openai-detector/resolve/main/merges.txt"
+ ),
},
}
@@ -320,7 +324,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
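
The `save_vocabulary` change swaps the compact `json.dumps` call for a sorted, indented dump with a trailing newline, which keeps re-saved vocab files byte-stable and easier to diff. A quick before/after illustration:

```python
# Illustration of the save_vocabulary change: sorted, indented JSON plus a
# trailing newline produces deterministic, diff-friendly vocab files.
import json

encoder = {"world": 2, "hello": 1}

compact = json.dumps(encoder, ensure_ascii=False)
# '{"world": 2, "hello": 1}'  (insertion order, single line, no newline)

pretty = json.dumps(encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
# '{\n  "hello": 1,\n  "world": 2\n}\n'  (keys sorted, one entry per line)
```
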
diff --git a/src/transformers/models/roberta/tokenization_roberta_fast.py b/src/transformers/models/roberta/tokenization_roberta_fast.py
index 7b774f69f19a..29381404c47f 100644
--- a/src/transformers/models/roberta/tokenization_roberta_fast.py
+++ b/src/transformers/models/roberta/tokenization_roberta_fast.py
@@ -35,7 +35,9 @@
"roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/vocab.json",
"distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/vocab.json",
"roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/vocab.json",
- "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/vocab.json",
+ "roberta-large-openai-detector": (
+ "https://huggingface.co/roberta-large-openai-detector/resolve/main/vocab.json"
+ ),
},
"merges_file": {
"roberta-base": "https://huggingface.co/roberta-base/resolve/main/merges.txt",
@@ -43,15 +45,21 @@
"roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/merges.txt",
"distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/merges.txt",
"roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/merges.txt",
- "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/merges.txt",
+ "roberta-large-openai-detector": (
+ "https://huggingface.co/roberta-large-openai-detector/resolve/main/merges.txt"
+ ),
},
"tokenizer_file": {
"roberta-base": "https://huggingface.co/roberta-base/resolve/main/tokenizer.json",
"roberta-large": "https://huggingface.co/roberta-large/resolve/main/tokenizer.json",
"roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/tokenizer.json",
"distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/tokenizer.json",
- "roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/tokenizer.json",
- "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/tokenizer.json",
+ "roberta-base-openai-detector": (
+ "https://huggingface.co/roberta-base-openai-detector/resolve/main/tokenizer.json"
+ ),
+ "roberta-large-openai-detector": (
+ "https://huggingface.co/roberta-large-openai-detector/resolve/main/tokenizer.json"
+ ),
},
}
@@ -227,8 +235,9 @@ def mask_token(self) -> str:
Roberta tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily
comprise the space before the *<mask>*.
"""
- if self._mask_token is None and self.verbose:
- logger.error("Using mask_token, but it is not set yet.")
+ if self._mask_token is None:
+ if self.verbose:
+ logger.error("Using mask_token, but it is not set yet.")
return None
return str(self._mask_token)
diff --git a/src/transformers/models/roformer/__init__.py b/src/transformers/models/roformer/__init__.py
index ec99c5a3b86a..909259ead601 100644
--- a/src/transformers/models/roformer/__init__.py
+++ b/src/transformers/models/roformer/__init__.py
@@ -17,7 +17,14 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -25,10 +32,20 @@
"tokenization_roformer": ["RoFormerTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_roformer_fast"] = ["RoFormerTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_roformer"] = [
"ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"RoFormerForCausalLM",
@@ -44,7 +61,12 @@
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_roformer"] = [
"TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRoFormerForCausalLM",
@@ -59,7 +81,12 @@
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_roformer"] = [
"FLAX_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"FlaxRoFormerForMaskedLM",
@@ -76,10 +103,20 @@
from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerOnnxConfig
from .tokenization_roformer import RoFormerTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_roformer_fast import RoFormerTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_roformer import (
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
RoFormerForCausalLM,
@@ -94,7 +131,12 @@
load_tf_weights_in_roformer,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_roformer import (
TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRoFormerForCausalLM,
@@ -108,7 +150,12 @@
TFRoFormerPreTrainedModel,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_roformer import (
FLAX_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
FlaxRoFormerForMaskedLM,
diff --git a/src/transformers/models/roformer/configuration_roformer.py b/src/transformers/models/roformer/configuration_roformer.py
index 2c5de2bbbe26..ea547ca52d1b 100644
--- a/src/transformers/models/roformer/configuration_roformer.py
+++ b/src/transformers/models/roformer/configuration_roformer.py
@@ -27,10 +27,18 @@
ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/config.json",
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json",
- "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/config.json",
- "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/config.json",
- "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/config.json",
- "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/config.json",
+ "junnyu/roformer_chinese_char_small": (
+ "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/config.json"
+ ),
+ "junnyu/roformer_chinese_char_base": (
+ "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/config.json"
+ ),
+ "junnyu/roformer_small_discriminator": (
+ "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/config.json"
+ ),
+ "junnyu/roformer_small_generator": (
+ "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/config.json"
+ ),
# See all RoFormer models at https://huggingface.co/models?filter=roformer
}
diff --git a/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py
index 33edf59f6bfd..0ab8b671d075 100755
--- a/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py
@@ -51,8 +51,10 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained BERT model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained BERT model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/roformer/modeling_flax_roformer.py b/src/transformers/models/roformer/modeling_flax_roformer.py
index 37dd72966646..011f1610488d 100644
--- a/src/transformers/models/roformer/modeling_flax_roformer.py
+++ b/src/transformers/models/roformer/modeling_flax_roformer.py
@@ -180,8 +180,8 @@ class FlaxRoFormerSelfAttention(nn.Module):
def setup(self) -> None:
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
- "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
- : {self.config.num_attention_heads}"
+ f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
+ f" : {self.config.num_attention_heads}"
)
self.query = nn.Dense(
@@ -456,8 +456,8 @@ def __call__(
if head_mask is not None:
if head_mask.shape[0] != (len(self.layers)):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for \
- {head_mask.shape[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
+ f" {head_mask.shape[0]}."
)
for i, layer in enumerate(self.layers):
diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py
index 738df5111922..120017abdff4 100644
--- a/src/transformers/models/roformer/modeling_roformer.py
+++ b/src/transformers/models/roformer/modeling_roformer.py
@@ -17,7 +17,7 @@
import math
import os
-from typing import Optional
+from typing import Optional, Tuple, Union
import numpy as np
import torch
@@ -699,8 +699,8 @@ class RoFormerPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = []
_keys_to_ignore_on_load_unexpected = [
- r"roformer\.embeddings_project\.weight",
- r"roformer\.embeddings_project\.bias",
+ r"roformer.embeddings_project.weight",
+ r"roformer.embeddings_project.bias",
]
def _init_weights(self, module):
@@ -835,19 +835,19 @@ class PreTrainedModel
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_values=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[BaseModelOutputWithPastAndCrossAttentions, Tuple[torch.Tensor]]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -984,18 +984,18 @@ def set_output_embeddings(self, new_embeddings):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[MaskedLMOutput, Tuple[torch.Tensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1080,21 +1080,21 @@ def set_output_embeddings(self, new_embeddings):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- head_mask=None,
- cross_attn_head_mask=None,
- past_key_values=None,
- labels=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[CausalLMOutputWithCrossAttentions, Tuple[torch.Tensor]]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -1246,16 +1246,16 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1341,16 +1341,16 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[MultipleChoiceModelOutput, Tuple[torch.Tensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -1432,16 +1432,16 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[TokenClassifierOutput, Tuple[torch.Tensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1510,17 +1510,17 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- head_mask=None,
- inputs_embeds=None,
- start_positions=None,
- end_positions=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[QuestionAnsweringModelOutput, Tuple[torch.Tensor]]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
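
The RoFormer modeling changes are mostly mechanical: every `forward` gains explicit `Optional[...]` parameter annotations and a `Union[ModelOutput, Tuple]` return annotation. A toy head (not the real RoFormer model) showing the same signature style:

```python
# Toy module mirroring the annotation pattern applied to the RoFormer forward
# signatures; SequenceClassifierOutput is the real transformers output class.
from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers.modeling_outputs import SequenceClassifierOutput


class ToyClassifier(nn.Module):
    def __init__(self, hidden_size: int = 8, num_labels: int = 2):
        super().__init__()
        self.proj = nn.Linear(hidden_size, num_labels)

    def forward(
        self,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor]]:
        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided in this sketch")
        return_dict = return_dict if return_dict is not None else True

        logits = self.proj(inputs_embeds)
        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(loss=loss, logits=logits)
```
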
diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py
index e5e3728c03fc..ac1efc72d089 100644
--- a/src/transformers/models/roformer/tokenization_roformer.py
+++ b/src/transformers/models/roformer/tokenization_roformer.py
@@ -31,10 +31,18 @@
"vocab_file": {
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
- "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt",
- "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt",
- "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt",
- "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt",
+ "junnyu/roformer_chinese_char_small": (
+ "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt"
+ ),
+ "junnyu/roformer_chinese_char_base": (
+ "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt"
+ ),
+ "junnyu/roformer_small_discriminator": (
+ "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt"
+ ),
+ "junnyu/roformer_small_generator": (
+ "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt"
+ ),
}
}
@@ -144,8 +152,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
diff --git a/src/transformers/models/roformer/tokenization_roformer_fast.py b/src/transformers/models/roformer/tokenization_roformer_fast.py
index 59644df74658..7b2cab568862 100644
--- a/src/transformers/models/roformer/tokenization_roformer_fast.py
+++ b/src/transformers/models/roformer/tokenization_roformer_fast.py
@@ -33,10 +33,18 @@
"vocab_file": {
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
- "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt",
- "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt",
- "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt",
- "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt",
+ "junnyu/roformer_chinese_char_small": (
+ "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt"
+ ),
+ "junnyu/roformer_chinese_char_base": (
+ "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt"
+ ),
+ "junnyu/roformer_small_discriminator": (
+ "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt"
+ ),
+ "junnyu/roformer_small_generator": (
+ "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt"
+ ),
}
}
diff --git a/src/transformers/models/segformer/__init__.py b/src/transformers/models/segformer/__init__.py
index fed4e8127cbd..2317237509a0 100644
--- a/src/transformers/models/segformer/__init__.py
+++ b/src/transformers/models/segformer/__init__.py
@@ -17,17 +17,31 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_torch_available,
+ is_vision_available,
+)
-_import_structure = {
- "configuration_segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"],
-}
+_import_structure = {"configuration_segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_segformer"] = ["SegformerFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_segformer"] = [
"SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"SegformerDecodeHead",
@@ -38,14 +52,39 @@
"SegformerPreTrainedModel",
]
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_segformer"] = [
+ "TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TFSegformerDecodeHead",
+ "TFSegformerForImageClassification",
+ "TFSegformerForSemanticSegmentation",
+ "TFSegformerModel",
+ "TFSegformerPreTrainedModel",
+ ]
+
if TYPE_CHECKING:
from .configuration_segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_segformer import SegformerFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_segformer import (
SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
SegformerDecodeHead,
@@ -55,7 +94,20 @@
SegformerModel,
SegformerPreTrainedModel,
)
-
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_segformer import (
+ TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TFSegformerDecodeHead,
+ TFSegformerForImageClassification,
+ TFSegformerForSemanticSegmentation,
+ TFSegformerModel,
+ TFSegformerPreTrainedModel,
+ )
else:
import sys
diff --git a/src/transformers/models/segformer/configuration_segformer.py b/src/transformers/models/segformer/configuration_segformer.py
index fa54c62c227c..faec5d6c4c9f 100644
--- a/src/transformers/models/segformer/configuration_segformer.py
+++ b/src/transformers/models/segformer/configuration_segformer.py
@@ -23,7 +23,9 @@
logger = logging.get_logger(__name__)
SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "nvidia/segformer-b0-finetuned-ade-512-512": "https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512/resolve/main/config.json",
+ "nvidia/segformer-b0-finetuned-ade-512-512": (
+ "https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512/resolve/main/config.json"
+ ),
# See all SegFormer models at https://huggingface.co/models?filter=segformer
}
@@ -122,8 +124,8 @@ def __init__(
if "reshape_last_stage" in kwargs and kwargs["reshape_last_stage"] is False:
warnings.warn(
- "Reshape_last_stage is set to False in this config. This argument is deprecated and will soon be removed, "
- "as the behaviour will default to that of reshape_last_stage = True.",
+ "Reshape_last_stage is set to False in this config. This argument is deprecated and will soon be"
+ " removed, as the behaviour will default to that of reshape_last_stage = True.",
FutureWarning,
)
diff --git a/src/transformers/models/segformer/feature_extraction_segformer.py b/src/transformers/models/segformer/feature_extraction_segformer.py
index c706c559af3c..0a9ae01ef121 100644
--- a/src/transformers/models/segformer/feature_extraction_segformer.py
+++ b/src/transformers/models/segformer/feature_extraction_segformer.py
@@ -158,8 +158,9 @@ def __call__(
if not valid_segmentation_maps:
raise ValueError(
- "Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example),"
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
+ "Segmentation maps must be of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single"
+ " example), `List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of"
+ " examples)."
)
is_batched = bool(
diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py
index 55ac976b3544..b8be4cdb70a6 100755
--- a/src/transformers/models/segformer/modeling_segformer.py
+++ b/src/transformers/models/segformer/modeling_segformer.py
@@ -86,21 +86,23 @@ class SegFormerImageClassifierOutput(ImageClassifierOutput):
# Copied from transformers.models.convnext.modeling_convnext.drop_path
-def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep=True):
+def drop_path(input, drop_prob: float = 0.0, training: bool = False, scale_by_keep=True):
"""
- Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
- DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
- Connect' is a different form of dropout in a separate paper... See discussion:
- https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
- argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
"""
if drop_prob == 0.0 or not training:
- return x
+ return input
keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize
- output = x.div(keep_prob) * random_tensor
+ output = input.div(keep_prob) * random_tensor
return output
@@ -108,13 +110,16 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep=T
class SegformerDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
- def __init__(self, drop_prob=None):
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
class SegformerOverlapPatchEmbeddings(nn.Module):
"""Construct the overlapping patch embeddings."""
@@ -780,6 +785,8 @@ def forward(
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits # shape (batch_size, num_labels, height, width)
+ >>> list(logits.shape)
+ [1, 150, 128, 128]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
@@ -799,7 +806,7 @@ def forward(
loss = None
if labels is not None:
- if self.config.num_labels == 1:
+ if not self.config.num_labels > 1:
raise ValueError("The number of labels should be greater than one")
else:
# upsample logits to the images' original size
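
For reference, the `drop_path` function touched above (argument renamed from `x` to `input`, docstring reflowed) implements stochastic depth: during training each sample's residual branch is either zeroed or rescaled by `1/keep_prob` so its expected value is unchanged. The sketch below restates it standalone with a tiny sanity check, assuming the same semantics as the diff:

```python
# Standalone restatement of the stochastic-depth drop_path from the diff.
import torch


def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over the remaining dimensions.
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    return input.div(keep_prob) * random_tensor


x = torch.ones(4, 3)
out = drop_path(x, drop_prob=0.5, training=True)
# Each row of `out` is either all zeros or all 2.0, so E[out] == x elementwise.
```
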
diff --git a/src/transformers/models/segformer/modeling_tf_segformer.py b/src/transformers/models/segformer/modeling_tf_segformer.py
new file mode 100644
index 000000000000..c2f4b2ff0c7c
--- /dev/null
+++ b/src/transformers/models/segformer/modeling_tf_segformer.py
@@ -0,0 +1,900 @@
+# coding=utf-8
+# Copyright 2022 NVIDIA The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TensorFlow SegFormer model."""
+
+import math
+from typing import Dict, Optional, Tuple, Union
+
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...file_utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ replace_return_docstrings,
+)
+from ...modeling_tf_outputs import TFBaseModelOutput, TFSemanticSegmenterOutput, TFSequenceClassifierOutput
+from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, keras_serializable, unpack_inputs
+from ...tf_utils import shape_list, stable_softmax
+from ...utils import logging
+from .configuration_segformer import SegformerConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "SegformerConfig"
+_FEAT_EXTRACTOR_FOR_DOC = "SegformerFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "nvidia/mit-b0"
+_EXPECTED_OUTPUT_SHAPE = [1, 256, 16, 16]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "nvidia/mit-b0"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "nvidia/segformer-b0-finetuned-ade-512-512",
+ # See all SegFormer models at https://huggingface.co/models?filter=segformer
+]
+
+
+# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->Segformer
+class TFSegformerDropPath(tf.keras.layers.Layer):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ References:
+ (1) github.com:rwightman/pytorch-image-models
+ """
+
+ def __init__(self, drop_path, **kwargs):
+ super().__init__(**kwargs)
+ self.drop_path = drop_path
+
+ def call(self, x, training=None):
+ if training:
+ keep_prob = 1 - self.drop_path
+ shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+ random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+ random_tensor = tf.floor(random_tensor)
+ return (x / keep_prob) * random_tensor
+ return x
+
+
+class TFSegformerOverlapPatchEmbeddings(tf.keras.layers.Layer):
+ """Construct the overlapping patch embeddings."""
+
+ def __init__(self, patch_size, stride, hidden_size, **kwargs):
+ super().__init__(**kwargs)
+ self.padding = tf.keras.layers.ZeroPadding2D(padding=patch_size // 2)
+ self.proj = tf.keras.layers.Conv2D(
+ filters=hidden_size, kernel_size=patch_size, strides=stride, padding="VALID", name="proj"
+ )
+
+ self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm")
+
+ def call(self, pixel_values: tf.Tensor) -> Tuple[tf.Tensor, int, int]:
+ embeddings = self.proj(self.padding(pixel_values))
+ height = shape_list(embeddings)[1]
+ width = shape_list(embeddings)[2]
+ hidden_dim = shape_list(embeddings)[3]
+ # (batch_size, height, width, num_channels) -> (batch_size, height*width, num_channels)
+ # this can be fed to a Transformer layer
+ embeddings = tf.reshape(embeddings, (-1, height * width, hidden_dim))
+ embeddings = self.layer_norm(embeddings)
+ return embeddings, height, width
+
+
+class TFSegformerEfficientSelfAttention(tf.keras.layers.Layer):
+ """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
+ paper](https://arxiv.org/abs/2102.12122)."""
+
+ def __init__(
+ self,
+ config: SegformerConfig,
+ hidden_size: int,
+ num_attention_heads: int,
+ sequence_reduction_ratio: int,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+
+ if self.hidden_size % self.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({self.num_attention_heads})"
+ )
+
+ self.attention_head_size = self.hidden_size // self.num_attention_heads
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
+
+ self.query = tf.keras.layers.Dense(self.all_head_size, name="query")
+ self.key = tf.keras.layers.Dense(self.all_head_size, name="key")
+ self.value = tf.keras.layers.Dense(self.all_head_size, name="value")
+
+ self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
+
+ self.sr_ratio = sequence_reduction_ratio
+ if sequence_reduction_ratio > 1:
+ self.sr = tf.keras.layers.Conv2D(
+ filters=hidden_size, kernel_size=sequence_reduction_ratio, strides=sequence_reduction_ratio, name="sr"
+ )
+ self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm")
+
+ def transpose_for_scores(self, tensor: tf.Tensor) -> tf.Tensor:
+ # Reshape from [batch_size, seq_length, all_head_size]
+ # to [batch_size, seq_length, num_attention_heads, attention_head_size]
+ batch_size = shape_list(tensor)[0]
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
+
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size]
+ # to [batch_size, num_attention_heads, seq_length, attention_head_size]
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ height: int,
+ width: int,
+ output_attentions: bool = False,
+ training: bool = False,
+ ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
+ batch_size = shape_list(hidden_states)[0]
+ num_channels = shape_list(hidden_states)[2]
+
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
+
+ if self.sr_ratio > 1:
+ # Reshape to (batch_size, height, width, num_channels)
+ hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))
+ # Apply sequence reduction
+ hidden_states = self.sr(hidden_states)
+ # Reshape back to (batch_size, seq_len, num_channels)
+ hidden_states = tf.reshape(hidden_states, (batch_size, -1, num_channels))
+ hidden_states = self.layer_norm(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+
+ scale = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
+ attention_scores = tf.divide(attention_scores, scale)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs, training=training)
+
+ context_layer = tf.matmul(attention_probs, value_layer)
+
+ context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
+ # (batch_size, seq_len_q, all_head_size)
+ context_layer = tf.reshape(context_layer, (batch_size, -1, self.all_head_size))
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+ return outputs
+
+
+class TFSegformerSelfOutput(tf.keras.layers.Layer):
+ def __init__(self, config: SegformerConfig, hidden_size: int, **kwargs):
+ super().__init__(**kwargs)
+ self.dense = tf.keras.layers.Dense(hidden_size, name="dense")
+ self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ return hidden_states
+
+
+class TFSegformerAttention(tf.keras.layers.Layer):
+ def __init__(
+ self,
+ config: SegformerConfig,
+ hidden_size: int,
+ num_attention_heads: int,
+ sequence_reduction_ratio: int,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.self = TFSegformerEfficientSelfAttention(
+ config=config,
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ sequence_reduction_ratio=sequence_reduction_ratio,
+ name="self",
+ )
+ self.dense_output = TFSegformerSelfOutput(config, hidden_size=hidden_size, name="output")
+
+ def call(
+ self, hidden_states: tf.Tensor, height: int, width: int, output_attentions: bool = False
+ ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
+ self_outputs = self.self(hidden_states, height, width, output_attentions)
+
+ attention_output = self.dense_output(self_outputs[0])
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class TFSegformerDWConv(tf.keras.layers.Layer):
+ def __init__(self, dim: int = 768, **kwargs):
+ super().__init__(**kwargs)
+ self.depthwise_convolution = tf.keras.layers.Conv2D(
+ filters=dim, kernel_size=3, strides=1, padding="same", groups=dim, name="dwconv"
+ )
+
+ def call(self, hidden_states: tf.Tensor, height: int, width: int) -> tf.Tensor:
+ batch_size = shape_list(hidden_states)[0]
+ num_channels = shape_list(hidden_states)[-1]
+ hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))
+ hidden_states = self.depthwise_convolution(hidden_states)
+
+ new_height = shape_list(hidden_states)[1]
+ new_width = shape_list(hidden_states)[2]
+ num_channels = shape_list(hidden_states)[3]
+ hidden_states = tf.reshape(hidden_states, (batch_size, new_height * new_width, num_channels))
+ return hidden_states
+
+
+class TFSegformerMixFFN(tf.keras.layers.Layer):
+ def __init__(
+ self,
+ config: SegformerConfig,
+ in_features: int,
+ hidden_features: int = None,
+ out_features: int = None,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ out_features = out_features or in_features
+ self.dense1 = tf.keras.layers.Dense(hidden_features, name="dense1")
+ self.depthwise_convolution = TFSegformerDWConv(hidden_features, name="dwconv")
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+ else:
+ self.intermediate_act_fn = config.hidden_act
+ self.dense2 = tf.keras.layers.Dense(out_features, name="dense2")
+ self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+
+ def call(self, hidden_states: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense1(hidden_states)
+ hidden_states = self.depthwise_convolution(hidden_states, height, width)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = self.dense2(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ return hidden_states
+
+
+class TFSegformerLayer(tf.keras.layers.Layer):
+ """This corresponds to the Block class in the original implementation."""
+
+ def __init__(
+ self,
+ config,
+ hidden_size: int,
+ num_attention_heads: int,
+ drop_path: float,
+ sequence_reduction_ratio: int,
+ mlp_ratio: int,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.layer_norm_1 = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_1")
+ self.attention = TFSegformerAttention(
+ config,
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ sequence_reduction_ratio=sequence_reduction_ratio,
+ name="attention",
+ )
+ self.drop_path = TFSegformerDropPath(drop_path) if drop_path > 0.0 else tf.keras.layers.Activation("linear")
+ self.layer_norm_2 = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_2")
+ mlp_hidden_size = int(hidden_size * mlp_ratio)
+ self.mlp = TFSegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size, name="mlp")
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ height: int,
+ width: int,
+ output_attentions: bool = False,
+ training: bool = False,
+ ) -> Tuple:
+ self_attention_outputs = self.attention(
+ self.layer_norm_1(hidden_states), # in Segformer, layernorm is applied before self-attention
+ height,
+ width,
+ output_attentions=output_attentions,
+ training=training,
+ )
+
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # first residual connection (with stochastic depth)
+ attention_output = self.drop_path(attention_output, training=training)
+ hidden_states = attention_output + hidden_states
+ mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)
+
+ # second residual connection (with stochastic depth)
+ mlp_output = self.drop_path(mlp_output, training=training)
+ layer_output = mlp_output + hidden_states
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
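`TFSegformerLayer` routes both residual branches through `TFSegformerDropPath` (defined earlier in this file, not shown in this hunk). As a rough sketch of the standard stochastic-depth idea it implements, with an invented helper name:

```python
import tensorflow as tf

def drop_path(x: tf.Tensor, drop_prob: float, training: bool) -> tf.Tensor:
    # Stochastic depth: during training, zero the whole residual branch for a random
    # subset of samples and rescale the survivors so the expected output is unchanged.
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)  # broadcast over everything but batch
    random_tensor = keep_prob + tf.random.uniform(shape, dtype=x.dtype)
    return (x / keep_prob) * tf.floor(random_tensor)
```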
+class TFSegformerEncoder(tf.keras.layers.Layer):
+ def __init__(self, config: SegformerConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+
+ # stochastic depth decay rule
+ drop_path_decays = [x.numpy() for x in tf.linspace(0.0, config.drop_path_rate, sum(config.depths))]
+
+ # patch embeddings
+ embeddings = []
+ for i in range(config.num_encoder_blocks):
+ embeddings.append(
+ TFSegformerOverlapPatchEmbeddings(
+ patch_size=config.patch_sizes[i],
+ stride=config.strides[i],
+ hidden_size=config.hidden_sizes[i],
+ name=f"patch_embeddings.{i}",
+ )
+ )
+ self.embeddings = embeddings
+
+ # Transformer blocks
+ blocks = []
+ cur = 0
+ for i in range(config.num_encoder_blocks):
+ # each block consists of layers
+ layers = []
+ if i != 0:
+ cur += config.depths[i - 1]
+ for j in range(config.depths[i]):
+ layers.append(
+ TFSegformerLayer(
+ config,
+ hidden_size=config.hidden_sizes[i],
+ num_attention_heads=config.num_attention_heads[i],
+ drop_path=drop_path_decays[cur + j],
+ sequence_reduction_ratio=config.sr_ratios[i],
+ mlp_ratio=config.mlp_ratios[i],
+ name=f"block.{i}.{j}",
+ )
+ )
+ blocks.append(layers)
+
+ self.block = blocks
+
+ # Layer norms
+ self.layer_norms = [
+ tf.keras.layers.LayerNormalization(epsilon=1e-05, name=f"layer_norm.{i}")
+ for i in range(config.num_encoder_blocks)
+ ]
+
+ def call(
+ self,
+ pixel_values: tf.Tensor,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ training: bool = False,
+ ) -> Union[Tuple, TFBaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ batch_size = shape_list(pixel_values)[0]
+
+ hidden_states = pixel_values
+ for idx, x in enumerate(zip(self.embeddings, self.block, self.layer_norms)):
+ embedding_layer, block_layer, norm_layer = x
+ # first, obtain patch embeddings
+ hidden_states, height, width = embedding_layer(hidden_states)
+
+ # second, send embeddings through blocks
+ # (each block consists of multiple layers, i.e., a list of layers)
+ for i, blk in enumerate(block_layer):
+ layer_outputs = blk(
+ hidden_states,
+ height,
+ width,
+ output_attentions,
+ training=training,
+ )
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ # third, apply layer norm
+ hidden_states = norm_layer(hidden_states)
+
+ # fourth, optionally reshape back to (batch_size, height, width, num_channels)
+ if idx != len(self.embeddings) - 1 or (idx == len(self.embeddings) - 1 and self.config.reshape_last_stage):
+ num_channels = shape_list(hidden_states)[-1]
+ hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
+ )
+
+
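The `drop_path_decays` computed in `__init__` follow the usual linear stochastic-depth schedule: the drop probability rises from 0 for the first layer to `config.drop_path_rate` for the last layer across all stages. A quick sketch with made-up MiT-b0-like values:

```python
import tensorflow as tf

depths = [2, 2, 2, 2]        # hypothetical per-stage depths
drop_path_rate = 0.1
decays = [float(x) for x in tf.linspace(0.0, drop_path_rate, sum(depths))]

# layer j of stage i receives decays[sum(depths[:i]) + j]
print([round(d, 3) for d in decays])
# [0.0, 0.014, 0.029, 0.043, 0.057, 0.071, 0.086, 0.1]
```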
+@keras_serializable
+class TFSegformerMainLayer(tf.keras.layers.Layer):
+ config_class = SegformerConfig
+
+ def __init__(self, config: SegformerConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ # hierarchical Transformer encoder
+ self.encoder = TFSegformerEncoder(config, name="encoder")
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: tf.Tensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[Tuple, TFBaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
+ # So change the input format from `NCHW` to `NHWC`.
+ # shape = (batch_size, in_height, in_width, in_channels=num_channels)
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+
+ encoder_outputs = self.encoder(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = encoder_outputs[0]
+ # Change to NCHW output format to have uniformity in the modules
+ sequence_output = tf.transpose(sequence_output, perm=[0, 3, 1, 2])
+
+ # Change the other hidden state outputs to NCHW as well
+ if output_hidden_states:
+ hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
+
+ if not return_dict:
+ if tf.greater(len(encoder_outputs[1:]), 0):
+ transposed_encoder_outputs = tuple(tf.transpose(v, perm=[0, 3, 1, 2]) for v in encoder_outputs[1:][0])
+ return (sequence_output,) + (transposed_encoder_outputs,)
+ else:
+ return (sequence_output,) + encoder_outputs[1:]
+
+ return TFBaseModelOutput(
+ last_hidden_state=sequence_output,
+ hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
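As the comments in `TFSegformerMainLayer.call` note, the public API keeps PyTorch-style `NCHW` tensors while the Keras convolutions run in `NHWC`. A tiny sketch of the two transposes, with an illustrative input shape:

```python
import tensorflow as tf

pixel_values = tf.random.normal((1, 3, 512, 512))        # NCHW, as passed in by the user
nhwc = tf.transpose(pixel_values, perm=(0, 2, 3, 1))      # -> (1, 512, 512, 3), what Conv2D expects
# ... the encoder runs entirely in NHWC ...
nchw = tf.transpose(nhwc, perm=(0, 3, 1, 2))              # -> (1, 3, 512, 512), back to the public layout
print(nhwc.shape, nchw.shape)
```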
+class TFSegformerPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = SegformerConfig
+ base_model_prefix = "segformer"
+ main_input_name = "pixel_values"
+
+ @property
+ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
+ """
+ Dummy inputs to build the network.
+
+ Returns:
+ `Dict[str, tf.Tensor]`: The dummy inputs.
+ """
+ VISION_DUMMY_INPUTS = tf.random.uniform(shape=(3, self.config.num_channels, 512, 512), dtype=tf.float32)
+ return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
+
+ @tf.function(
+ input_signature=[
+ {
+ "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
+ }
+ ]
+ )
+ def serving(self, inputs):
+ """
+ Method used for serving the model.
+
+ Args:
+ inputs (`Dict[str, tf.Tensor]`):
+ The input of the saved model as a dictionary of tensors.
+ """
+ output = self.call(inputs)
+
+ return self.serving_output(output)
+
+
+SEGFORMER_START_DOCSTRING = r"""
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`SegformerConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SEGFORMER_INPUTS_DOCSTRING = r"""
+
+ Args:
+ pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, each example of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
+ [`AutoFeatureExtractor.__call__`] for details.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attention tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can only be used in eager mode; in graph mode the value in the
+ config will be used instead.
+
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can only be used in eager mode; in graph mode the value in the config will be
+ used instead.
+
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode; in graph mode the value will always be set to `True`.
+
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ "The bare SegFormer encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.",
+ SEGFORMER_START_DOCSTRING,
+)
+class TFSegformerModel(TFSegformerPreTrainedModel):
+ def __init__(self, config: SegformerConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.config = config
+
+ # hierarchical Transformer encoder
+ self.segformer = TFSegformerMainLayer(config, name="segformer")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def call(
+ self,
+ pixel_values: tf.Tensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[Tuple, TFBaseModelOutput]:
+ outputs = self.segformer(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ return outputs
+
+ def serving_output(self, output: TFBaseModelOutput) -> TFBaseModelOutput:
+ # hidden_states and attentions are not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
+ return TFBaseModelOutput(
+ last_hidden_state=output.last_hidden_state,
+ hidden_states=output.hidden_states,
+ attentions=output.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ SegFormer Model transformer with an image classification head on top (a linear layer on top of the final hidden
+ states) e.g. for ImageNet.
+ """,
+ SEGFORMER_START_DOCSTRING,
+)
+class TFSegformerForImageClassification(TFSegformerPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: SegformerConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+ self.segformer = TFSegformerMainLayer(config, name="segformer")
+
+ # Classifier head
+ self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=TFSequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ labels: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TFSequenceClassifierOutput]:
+ outputs = self.segformer(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ # convert last hidden states to (batch_size, height*width, hidden_size)
+ batch_size = shape_list(sequence_output)[0]
+ sequence_output = tf.transpose(sequence_output, perm=[0, 2, 3, 1])
+ sequence_output = tf.reshape(sequence_output, (batch_size, -1, self.config.hidden_sizes[-1]))
+
+ # global average pooling
+ sequence_output = tf.reduce_mean(sequence_output, axis=1)
+
+ logits = self.classifier(sequence_output)
+
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+ def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
+ # hidden_states and attentions are not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
+ return TFSequenceClassifierOutput(
+ logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
+ )
+
+
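The classification head above flattens the final `NCHW` feature map into a token sequence, mean-pools over the tokens, and feeds the pooled vector to a dense classifier. A toy version of that pooling with invented sizes:

```python
import tensorflow as tf

hidden_size, num_labels = 256, 1000                       # hypothetical values
feature_map = tf.random.normal((1, hidden_size, 16, 16))  # (batch, channels, height, width)

tokens = tf.transpose(feature_map, perm=[0, 2, 3, 1])     # -> (1, 16, 16, 256)
tokens = tf.reshape(tokens, (1, -1, hidden_size))         # -> (1, 256, 256): one token per spatial location
pooled = tf.reduce_mean(tokens, axis=1)                   # -> (1, 256): global average pooling
logits = tf.keras.layers.Dense(num_labels)(pooled)        # -> (1, 1000)
print(logits.shape)
```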
+class TFSegformerMLP(tf.keras.layers.Layer):
+ """
+ Linear Embedding.
+ """
+
+ def __init__(self, config: SegformerConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.proj = tf.keras.layers.Dense(config.decoder_hidden_size, name="proj")
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ height = shape_list(hidden_states)[1]
+ width = shape_list(hidden_states)[2]
+ hidden_dim = shape_list(hidden_states)[-1]
+ hidden_states = tf.reshape(hidden_states, (-1, height * width, hidden_dim))
+ hidden_states = self.proj(hidden_states)
+ return hidden_states
+
+
+class TFSegformerDecodeHead(TFSegformerPreTrainedModel):
+ def __init__(self, config: SegformerConfig, **kwargs):
+ super().__init__(config, **kwargs)
+ # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
+ mlps = []
+ for i in range(config.num_encoder_blocks):
+ mlp = TFSegformerMLP(config, name=f"linear_c.{i}")
+ mlps.append(mlp)
+ self.mlps = mlps
+
+ # the following 3 layers implement the ConvModule of the original implementation
+ self.linear_fuse = tf.keras.layers.Conv2D(
+ filters=config.decoder_hidden_size, kernel_size=1, use_bias=False, name="linear_fuse"
+ )
+ self.batch_norm = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="batch_norm")
+ self.activation = tf.keras.layers.Activation("relu")
+
+ self.dropout = tf.keras.layers.Dropout(config.classifier_dropout_prob)
+ self.classifier = tf.keras.layers.Conv2D(filters=config.num_labels, kernel_size=1, name="classifier")
+
+ self.config = config
+
+ def call(self, encoder_hidden_states, training: bool = False):
+ batch_size = shape_list(encoder_hidden_states[-1])[0]
+
+ all_hidden_states = ()
+ for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.mlps):
+ if self.config.reshape_last_stage is False and len(shape_list(encoder_hidden_state)) == 3:
+ height = tf.math.sqrt(tf.cast(shape_list(encoder_hidden_state)[1], tf.float32))
+ height = width = tf.cast(height, tf.int32)
+ encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1))
+
+ # unify channel dimension
+ encoder_hidden_state = tf.transpose(encoder_hidden_state, perm=[0, 2, 3, 1])
+ height = shape_list(encoder_hidden_state)[1]
+ width = shape_list(encoder_hidden_state)[2]
+ encoder_hidden_state = mlp(encoder_hidden_state)
+ encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1))
+
+ # upsample
+ temp_state = tf.transpose(encoder_hidden_states[0], perm=[0, 2, 3, 1])
+ upsample_resolution = shape_list(temp_state)[1:-1]
+ encoder_hidden_state = tf.image.resize(encoder_hidden_state, size=upsample_resolution, method="bilinear")
+ all_hidden_states += (encoder_hidden_state,)
+
+ hidden_states = self.linear_fuse(tf.concat(all_hidden_states[::-1], axis=-1))
+ hidden_states = self.batch_norm(hidden_states, training=training)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+
+ # logits of shape (batch_size, height/4, width/4, num_labels)
+ logits = self.classifier(hidden_states)
+
+ return logits
+
+
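`TFSegformerDecodeHead` projects every stage's feature map to `decoder_hidden_size`, bilinearly upsamples all of them to the resolution of the first (highest-resolution) stage, concatenates, and fuses with a 1x1 convolution. A shape-only sketch with invented stage shapes for a 512x512 input:

```python
import tensorflow as tf

decoder_hidden_size = 256
# hypothetical NHWC stage outputs (strides 4, 8, 16, 32 with growing channel counts)
stages = [tf.random.normal((1, 512 // s, 512 // s, c)) for s, c in zip((4, 8, 16, 32), (32, 64, 160, 256))]

target_hw = (stages[0].shape[1], stages[0].shape[2])       # 128x128, i.e. 1/4 of the input resolution
unified = []
for feat in stages:
    feat = tf.keras.layers.Dense(decoder_hidden_size)(feat)                   # unify the channel dimension
    unified.append(tf.image.resize(feat, size=target_hw, method="bilinear"))  # upsample to 1/4 resolution

fused = tf.keras.layers.Conv2D(decoder_hidden_size, kernel_size=1, use_bias=False)(
    tf.concat(unified[::-1], axis=-1)                      # deepest stage first, as in the head above
)
print(fused.shape)  # (1, 128, 128, 256); the classifier then maps this to num_labels channels
```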
+@add_start_docstrings(
+ """SegFormer Model transformer with an all-MLP decode head on top e.g. for ADE20k, CityScapes.""",
+ SEGFORMER_START_DOCSTRING,
+)
+class TFSegformerForSemanticSegmentation(TFSegformerPreTrainedModel):
+ def __init__(self, config: SegformerConfig, **kwargs):
+ super().__init__(config, **kwargs)
+ self.segformer = TFSegformerMainLayer(config, name="segformer")
+ self.decode_head = TFSegformerDecodeHead(config, name="decode_head")
+
+ def hf_compute_loss(self, logits, labels):
+ # upsample logits to the images' original size
+ # `labels` is of shape (batch_size, height, width)
+ label_interp_shape = shape_list(labels)[1:]
+
+ upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
+ # compute weighted loss
+ loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
+
+ def masked_loss(real, pred):
+ unmasked_loss = loss_fct(real, pred)
+ mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype)
+ masked_loss = unmasked_loss * mask
+ # Reduction strategy in a similar spirit to
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210
+ reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask)
+ return tf.reshape(reduced_masked_loss, (1,))
+
+ return masked_loss(labels, upsampled_logits)
+
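`hf_compute_loss` upsamples the logits to the label resolution, computes per-pixel cross-entropy, and averages only over pixels whose label differs from `semantic_loss_ignore_index`. A toy sketch of the masked-mean reduction (ignore index and shapes invented; ignored labels are clamped to 0 before the lookup so the sketch runs on any device):

```python
import tensorflow as tf

ignore_index = 255
labels = tf.constant([[[0, 1], [255, 2]]])                 # (1, 2, 2); one pixel is ignored
logits = tf.random.normal((1, 2, 2, 3))                    # (1, h, w, num_labels), already at label size

loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
safe_labels = tf.where(labels == ignore_index, tf.zeros_like(labels), labels)
per_pixel = loss_fct(safe_labels, logits)                  # (1, 2, 2) unreduced loss
mask = tf.cast(labels != ignore_index, per_pixel.dtype)
loss = tf.reduce_sum(per_pixel * mask) / tf.reduce_sum(mask)   # mean over the three valid pixels
print(loss)
```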
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ pixel_values: tf.Tensor,
+ labels: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TFSemanticSegmenterOutput]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1`, a (per-pixel) classification loss is computed
+ (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
+ >>> model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
+
+ >>> inputs = feature_extractor(images=image, return_tensors="tf")
+ >>> outputs = model(**inputs, training=False)
+ >>> # logits are of shape (batch_size, num_labels, height, width)
+ >>> logits = outputs.logits
+ >>> list(logits.shape)
+ [1, 150, 128, 128]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.segformer(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=True, # we need the intermediate hidden states
+ return_dict=return_dict,
+ )
+
+ encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+ logits = self.decode_head(encoder_hidden_states)
+
+ loss = None
+ if labels is not None:
+ if not self.config.num_labels > 1:
+ raise ValueError("The number of labels should be greater than one")
+ else:
+ loss = self.hf_compute_loss(logits=logits, labels=labels)
+
+ # make logits of shape (batch_size, num_labels, height, width) to
+ # keep them consistent across APIs
+ logits = tf.transpose(logits, perm=[0, 3, 1, 2])
+
+ if not return_dict:
+ if output_hidden_states:
+ output = (logits,) + outputs[1:]
+ else:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSemanticSegmenterOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=outputs.attentions,
+ )
+
+ def serving_output(self, output: TFSemanticSegmenterOutput) -> TFSemanticSegmenterOutput:
+ # hidden_states and attentions are not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
+ return TFSemanticSegmenterOutput(
+ logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
+ )
diff --git a/src/transformers/models/sew/__init__.py b/src/transformers/models/sew/__init__.py
index 4ee9380137d1..bfe39bea1bdc 100644
--- a/src/transformers/models/sew/__init__.py
+++ b/src/transformers/models/sew/__init__.py
@@ -17,14 +17,17 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
-_import_structure = {
- "configuration_sew": ["SEW_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWConfig"],
-}
+_import_structure = {"configuration_sew": ["SEW_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_sew"] = [
"SEW_PRETRAINED_MODEL_ARCHIVE_LIST",
"SEWForCTC",
@@ -36,7 +39,12 @@
if TYPE_CHECKING:
from .configuration_sew import SEW_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_sew import (
SEW_PRETRAINED_MODEL_ARCHIVE_LIST,
SEWForCTC,
diff --git a/src/transformers/models/sew/configuration_sew.py b/src/transformers/models/sew/configuration_sew.py
index ad6a6afa6992..c955c0e48fe3 100644
--- a/src/transformers/models/sew/configuration_sew.py
+++ b/src/transformers/models/sew/configuration_sew.py
@@ -76,15 +76,15 @@ class SEWConfig(PretrainedConfig):
feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
- conv_dim (`Tuple[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`):
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
- conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`):
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
- of *conv_stride* defines the number of convolutional layers and has to match the the length of *conv_dim*.
- conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`):
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
- length of *conv_kernel* defines the number of convolutional layers and has to match the the length of
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
*conv_dim*.
conv_bias (`bool`, *optional*, defaults to `False`):
Whether the 1D convolutional layers have a bias.
diff --git a/src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py
index 6449288810f4..58c0338a850d 100644
--- a/src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py
@@ -67,9 +67,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -137,28 +138,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py
index ac2a6293cb95..632f7d4880f1 100644
--- a/src/transformers/models/sew/modeling_sew.py
+++ b/src/transformers/models/sew/modeling_sew.py
@@ -174,7 +174,7 @@ def compute_num_masked_span(input_length):
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
- # add offset to the starting indexes so that that indexes now create a span
+ # add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
@@ -489,7 +489,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -505,7 +506,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -526,7 +528,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -640,7 +643,8 @@ def forward(
attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()
# extend attention_mask
- attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
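The change from the hard-coded `-10000.0` to `torch.finfo(dtype).min` makes the additive attention-mask penalty track the compute dtype instead of relying on a fixed constant. A small, self-contained sketch of what the new expression produces for different dtypes (mask values invented):

```python
import torch

attention_mask = torch.tensor([[1, 1, 0]])                 # 0 marks a padded position
for dtype in (torch.float32, torch.float16):
    inverted = 1.0 - attention_mask[:, None, None, :].to(dtype)
    additive = inverted * torch.finfo(dtype).min            # most negative value representable in this dtype
    print(dtype, additive[0, 0, 0].tolist())
    # float32 -> ~ -3.4e38 at the masked position; float16 -> -65504.0
```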
diff --git a/src/transformers/models/sew_d/__init__.py b/src/transformers/models/sew_d/__init__.py
index bc5774004057..905bfb0f5b68 100644
--- a/src/transformers/models/sew_d/__init__.py
+++ b/src/transformers/models/sew_d/__init__.py
@@ -17,14 +17,17 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
-_import_structure = {
- "configuration_sew_d": ["SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWDConfig"],
-}
+_import_structure = {"configuration_sew_d": ["SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWDConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_sew_d"] = [
"SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST",
"SEWDForCTC",
@@ -36,7 +39,12 @@
if TYPE_CHECKING:
from .configuration_sew_d import SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWDConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_sew_d import (
SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST,
SEWDForCTC,
diff --git a/src/transformers/models/sew_d/configuration_sew_d.py b/src/transformers/models/sew_d/configuration_sew_d.py
index 996338cb0f05..8461dfef4511 100644
--- a/src/transformers/models/sew_d/configuration_sew_d.py
+++ b/src/transformers/models/sew_d/configuration_sew_d.py
@@ -94,15 +94,15 @@ class SEWDConfig(PretrainedConfig):
feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
- conv_dim (`Tuple[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`):
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
- conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`):
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
- of *conv_stride* defines the number of convolutional layers and has to match the the length of *conv_dim*.
- conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`):
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
- length of *conv_kernel* defines the number of convolutional layers and has to match the the length of
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
*conv_dim*.
conv_bias (`bool`, *optional*, defaults to `False`):
Whether the 1D convolutional layers have a bias.
diff --git a/src/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py
index e6529eea04dd..942add470b9c 100644
--- a/src/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py
@@ -69,9 +69,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -141,28 +142,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py
index a297e4c7b25b..a9a231aec1d8 100644
--- a/src/transformers/models/sew_d/modeling_sew_d.py
+++ b/src/transformers/models/sew_d/modeling_sew_d.py
@@ -175,7 +175,7 @@ def compute_num_masked_span(input_length):
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
- # add offset to the starting indexes so that that indexes now create a span
+ # add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
@@ -261,7 +261,7 @@ def get_mask(input, local_context):
mask = local_context.mask if local_context.reuse_mask else None
if dropout > 0 and mask is None:
- mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
+ mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
if isinstance(local_context, DropoutContext):
if local_context.mask is None:
@@ -532,9 +532,9 @@ class XSoftmax(torch.autograd.Function):
@staticmethod
def forward(self, input, mask, dim):
self.dim = dim
- rmask = ~(mask.bool())
+ rmask = ~(mask.to(torch.bool))
- output = input.masked_fill(rmask, float("-inf"))
+ output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
output = torch.softmax(output, self.dim)
output.masked_fill_(rmask, 0)
self.save_for_backward(output)
@@ -557,7 +557,9 @@ def symbolic(g, self, mask, dim):
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
)
- output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
+ output = masked_fill(
+ g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
+ )
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
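For context on the two `XSoftmax` changes above (the `torch.bool` cast and the dtype-aware fill value), here is a toy re-enactment of the forward pass with invented scores and mask:

```python
import torch

scores = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1, 1, 0]])                            # 0 = position to mask out

rmask = ~mask.to(torch.bool)
filled = scores.masked_fill(rmask, torch.finfo(scores.dtype).min)  # dtype-aware "minus infinity"
probs = torch.softmax(filled, dim=-1)
probs = probs.masked_fill(rmask, 0.0)                       # force the masked entries to exactly zero
print(probs)  # tensor([[0.2689, 0.7311, 0.0000]])
```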
@@ -593,6 +595,23 @@ def backward(ctx, grad_output):
else:
return grad_output, None
+ @staticmethod
+ def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
+ from torch.onnx import symbolic_opset12
+
+ dropout_p = local_ctx
+ if isinstance(local_ctx, DropoutContext):
+ dropout_p = local_ctx.dropout
+ # StableDropout only calls this function when training.
+ train = True
+ # TODO: We should check if the opset_version being used to export
+ # is > 12 here, but there's no good way to do that. As-is, if the
+ # opset_version < 12, export will fail with a CheckerError.
+ # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
+ # if opset_version < 12:
+ # return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
+ return symbolic_opset12.dropout(g, input, dropout_p, train)
+
# Copied from transformers.models.deberta.modeling_deberta.StableDropout
class StableDropout(nn.Module):
@@ -711,7 +730,7 @@ def __init__(self, config):
def transpose_for_scores(self, x, attention_heads):
new_x_shape = x.size()[:-1] + (attention_heads, -1)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
def forward(
@@ -792,7 +811,7 @@ def forward(
.contiguous()
)
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
if output_attentions:
return (context_layer, attention_probs)
else:
diff --git a/src/transformers/models/speech_encoder_decoder/__init__.py b/src/transformers/models/speech_encoder_decoder/__init__.py
index a040990864a9..4eea93eacddc 100644
--- a/src/transformers/models/speech_encoder_decoder/__init__.py
+++ b/src/transformers/models/speech_encoder_decoder/__init__.py
@@ -18,26 +18,44 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available
-_import_structure = {
- "configuration_speech_encoder_decoder": ["SpeechEncoderDecoderConfig"],
-}
+_import_structure = {"configuration_speech_encoder_decoder": ["SpeechEncoderDecoderConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_speech_encoder_decoder"] = ["SpeechEncoderDecoderModel"]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_speech_encoder_decoder"] = ["FlaxSpeechEncoderDecoderModel"]
if TYPE_CHECKING:
from .configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_speech_encoder_decoder import SpeechEncoderDecoderModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
else:
diff --git a/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py
index ca3e4966aaf9..8b648f8e21bc 100644
--- a/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py
+++ b/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py
@@ -77,7 +77,8 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
if "encoder" not in kwargs or "decoder" not in kwargs:
raise ValueError(
- f"A configuraton of type {self.model_type} cannot be instantiated because not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
+ f"A configuraton of type {self.model_type} cannot be instantiated because not both `encoder` and"
+ f" `decoder` sub-configurations are passed, but only {kwargs}"
)
encoder_config = kwargs.pop("encoder")
diff --git a/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py b/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py
index 3c25ab706f4e..8680f96e50d5 100644
--- a/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py
+++ b/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py
@@ -75,9 +75,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -147,28 +148,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py b/src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py
index 40433bba1344..0a4bc48dea32 100644
--- a/src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py
+++ b/src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py
@@ -77,9 +77,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -153,28 +154,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
index 0326fee63eea..cd304fa0c0a8 100644
--- a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
+++ b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
@@ -357,10 +357,10 @@ def __init__(
# Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
- "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
- "it has to be equal to the encoder's `hidden_size`. "
- f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
- f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
)
# make sure input & output embeddings are not tied
@@ -389,7 +389,8 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
if not decoder_batch_size == batch_size:
raise ValueError(
- f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
+ f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder"
+ f" and {decoder_batch_size} for decoder."
)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
@@ -713,7 +714,8 @@ def __call__(
# prepare decoder inputs
if decoder_input_ids is None:
raise ValueError(
- "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
+ "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must"
+ " be specified as an input argument."
)
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
@@ -895,10 +897,9 @@ def from_encoder_decoder_pretrained(
)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
- f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
- f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
- "cross attention layers."
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
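
For context on the `logger.info` path above: it fires whenever the chosen decoder checkpoint is not yet configured as a decoder, in which case cross-attention layers are added and randomly initialized. A usage sketch, assuming network access and that the (illustrative) checkpoints below ship Flax weights:

from transformers import FlaxSpeechEncoderDecoderModel

# Any Flax speech encoder + Flax causal-LM decoder pair should work here.
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/wav2vec2-large-lv60", "facebook/bart-large"
)
# After the call, the decoder config has been flipped exactly as in the lines above.
assert model.config.decoder.is_decoder and model.config.decoder.add_cross_attention
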
diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
index 1dbba59f9ef3..388be2449947 100644
--- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
+++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
@@ -135,7 +135,7 @@
into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
and conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details.
- input_features (`torch.LongTensor` of shape `(batch_size, sequence_length, feature_size)`, *optional*):
+ input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`, *optional*):
Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.*
via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
@@ -199,10 +199,10 @@ def __init__(
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
- "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
- "it has to be equal to the encoder's `hidden_size`. "
- f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
- f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
)
# initialize with config
@@ -221,11 +221,13 @@ def __init__(
if self.encoder.config.to_dict() != self.config.encoder.to_dict():
logger.warning(
- f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+ f" {self.config.encoder}"
)
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
logger.warning(
- f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+ f" {self.config.decoder}"
)
# make sure that the individual model's config refers to the shared config
@@ -410,10 +412,9 @@ def from_encoder_decoder_pretrained(
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
- f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
- f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
- "cross attention layers."
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
@@ -481,8 +482,7 @@ def forward(
'Mr. Quilter ist der Apostel der Mittelschicht und wir freuen uns, sein Evangelium willkommen heißen zu können.'
>>> # Training: Train model on English transcription
- >>> with processor.as_target_processor():
- ... labels = processor(ds[0]["text"], return_tensors="pt").input_ids
+ >>> labels = processor(text=ds[0]["text"], return_tensors="pt").input_ids
>>> loss = model(input_values, labels=labels).loss
>>> loss.backward()
@@ -599,8 +599,8 @@ def prepare_inputs_for_generation(
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
- "Resizing the embedding layers via the SpeechEncoderDecoderModel directly is not supported. "
- "Please use the respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
+ "Resizing the embedding layers via the SpeechEncoderDecoderModel directly is not supported. Please use the"
+ " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
)
def _reorder_cache(self, past, beam_idx):
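
The `NotImplementedError` above points users to resizing the embeddings on the wrapped decoder rather than on the composite model. A minimal sketch of that workaround, using tiny randomly initialized configs purely for illustration (no pretrained weights involved):

from transformers import (
    BertConfig,
    SpeechEncoderDecoderConfig,
    SpeechEncoderDecoderModel,
    Wav2Vec2Config,
)

encoder_config = Wav2Vec2Config(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
decoder_config = BertConfig(
    hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64,
    vocab_size=100, is_decoder=True, add_cross_attention=True,
)
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
model = SpeechEncoderDecoderModel(config)

model.decoder.resize_token_embeddings(110)   # supported
# model.resize_token_embeddings(110)         # raises the NotImplementedError shown above
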
diff --git a/src/transformers/models/speech_to_text/__init__.py b/src/transformers/models/speech_to_text/__init__.py
index 0cccf6672136..20eba2bf6a2d 100644
--- a/src/transformers/models/speech_to_text/__init__.py
+++ b/src/transformers/models/speech_to_text/__init__.py
@@ -17,26 +17,45 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_speech_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_sentencepiece_available,
+ is_speech_available,
+ is_tf_available,
+ is_torch_available,
+)
_import_structure = {
- "configuration_speech_to_text": [
- "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
- "Speech2TextConfig",
- ],
+ "configuration_speech_to_text": ["SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2TextConfig"],
}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]
-if is_speech_available():
+try:
+ if not is_speech_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_speech_to_text"] = ["Speech2TextFeatureExtractor"]
if is_sentencepiece_available():
_import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_speech_to_text"] = [
"TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSpeech2TextForConditionalGeneration",
@@ -44,7 +63,12 @@
"TFSpeech2TextPreTrainedModel",
]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_speech_to_text"] = [
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"Speech2TextForConditionalGeneration",
@@ -56,16 +80,31 @@
if TYPE_CHECKING:
from .configuration_speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_speech_to_text import Speech2TextTokenizer
- if is_speech_available():
+ try:
+ if not is_speech_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
if is_sentencepiece_available():
from .processing_speech_to_text import Speech2TextProcessor
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_speech_to_text import (
TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSpeech2TextForConditionalGeneration,
@@ -73,7 +112,12 @@
TFSpeech2TextPreTrainedModel,
)
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_speech_to_text import (
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
Speech2TextForConditionalGeneration,
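
The `__init__.py` rewrite above converts every soft-dependency check into the `try`/`except OptionalDependencyNotAvailable` pattern, so a missing backend silently skips registering the heavy submodules. A self-contained sketch of the same control flow, with a hypothetical availability flag standing in for `is_torch_available()` and friends (the real exception lives in `transformers.utils`):

class OptionalDependencyNotAvailable(BaseException):
    """Stand-in for transformers.utils.OptionalDependencyNotAvailable, for illustration only."""

def fake_backend_available() -> bool:
    # Pretend the backend (torch, tf, sentencepiece, ...) is not installed.
    return False

_import_structure = {"configuration_example": ["ExampleConfig"]}

try:
    if not fake_backend_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass  # nothing registered here; dummy objects elsewhere raise a helpful error on import
else:
    _import_structure["modeling_example"] = ["ExampleModel"]

print(_import_structure)  # {'configuration_example': ['ExampleConfig']}
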
diff --git a/src/transformers/models/speech_to_text/configuration_speech_to_text.py b/src/transformers/models/speech_to_text/configuration_speech_to_text.py
index f08bbf51e1b2..f12be50b538c 100644
--- a/src/transformers/models/speech_to_text/configuration_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/configuration_speech_to_text.py
@@ -21,7 +21,9 @@
logger = logging.get_logger(__name__)
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "facebook/s2t-small-librispeech-asr": "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/config.json",
+ "facebook/s2t-small-librispeech-asr": (
+ "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/config.json"
+ ),
# See all Speech2Text models at https://huggingface.co/models?filter=speech_to_text
}
diff --git a/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py b/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py
index df8bc485364f..6c1cd993fe46 100644
--- a/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py
+++ b/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py
@@ -102,7 +102,8 @@ def convert_fairseq_s2t_checkpoint_to_tfms(checkpoint_path, pytorch_dump_folder_
]
):
raise ValueError(
- f"Only `encoder.embed_positions.weights` and `decoder.embed_positions.weights` are allowed to be missing, but all the following weights are missing {missing}"
+ "Only `encoder.embed_positions.weights` and `decoder.embed_positions.weights` are allowed to be missing,"
+ f" but all the following weights are missing {missing}"
)
if tie_embeds:
diff --git a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
index e6ff52f18360..4294c48c71f0 100644
--- a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
@@ -190,8 +190,9 @@ def __call__(
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of {self.sampling_rate}. "
- f"Please make sure that the provided `raw_speech` input was sampled with {self.sampling_rate} and not {sampling_rate}."
+ f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+ f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
+ f" {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
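
The reflowed error above is raised whenever the caller passes a `sampling_rate` that disagrees with the one the extractor was configured with; omitting it only triggers the warning below. A small sketch of both paths, building the extractor from its defaults (16 kHz) instead of downloading one, and assuming a speech backend such as torchaudio is installed:

import numpy as np
from transformers import Speech2TextFeatureExtractor

extractor = Speech2TextFeatureExtractor()  # defaults to sampling_rate=16000
waveform = (np.random.randn(16000) * 0.1).astype(np.float32)  # roughly 1 s of noise at 16 kHz

features = extractor(waveform, sampling_rate=16000, return_tensors="np")  # OK
print(features["input_features"].shape)

try:
    extractor(waveform, sampling_rate=8000, return_tensors="np")
except ValueError as err:
    print(err)  # the message reformatted above
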
diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py
index 7c2e1835370a..a5a2998f22c9 100755
--- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py
@@ -69,7 +69,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
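
The change above replaces `float("-inf")` with the dtype's finite minimum when building the causal mask, which keeps fully masked rows well-defined after the softmax (an all `-inf` row produces NaNs, most visibly in fp16). A small sketch of the difference, under that assumption about the motivation:

import torch

scores = torch.zeros(1, 4, dtype=torch.float16)
inf_mask = torch.full((1, 4), float("-inf"), dtype=torch.float16)
finite_mask = torch.full((1, 4), torch.finfo(torch.float16).min, dtype=torch.float16)

print(torch.softmax(scores + inf_mask, dim=-1))     # all NaN: exp(-inf) terms sum to zero
print(torch.softmax(scores + finite_mask, dim=-1))  # uniform weights, no NaN
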
@@ -91,7 +91,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class Conv1dSubsampler(nn.Module):
@@ -292,7 +292,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -308,7 +309,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -329,7 +331,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -596,7 +599,7 @@ def _get_feature_vector_attention_mask(self, feature_vector_length, attention_ma
SPEECH_TO_TEXT_INPUTS_DOCSTRING = r"""
Args:
- input_features (`torch.LongTensor` of shape `(batch_size, sequence_length, feature_size)`):
+ input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`):
Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.*
via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
@@ -625,9 +628,9 @@ def _get_feature_vector_attention_mask(self, feature_vector_length, attention_ma
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
- If you want to change padding behavior, you should read [`modeling_speech_to_text._prepare_decoder_inputs`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
+ If you want to change padding behavior, you should read
+ [`modeling_speech_to_text._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the
+ paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
@@ -660,8 +663,8 @@ def _get_feature_vector_attention_mask(self, feature_vector_length, attention_ma
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
- ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. decoder_inputs_embeds (`torch.FloatTensor`
- of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`. decoder_inputs_embeds (`torch.FloatTensor` of
+ shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
`decoder_input_ids` you can choose to directly pass an embedded representation. If `past_key_values` is
used, optionally only the last `decoder_inputs_embeds` have to be input (see `past_key_values`). This is
useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
@@ -885,7 +888,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -962,8 +965,8 @@ def forward(
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
- all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor`
- of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
+ shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
control over how to convert `input_ids` indices into associated vectors than the model's internal
embedding lookup matrix.
@@ -1024,9 +1027,10 @@ def forward(
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
- assert attn_mask.size()[0] == (
- len(self.layers)
- ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ assert attn_mask.size()[0] == (len(self.layers)), (
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
@@ -1041,7 +1045,8 @@ def forward(
if use_cache:
logger.warning(
- "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
+ "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache ="
+ " False`..."
)
use_cache = False
@@ -1247,8 +1252,8 @@ def forward(
class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
- r"encoder\.version",
- r"decoder\.version",
+ r"encoder.version",
+ r"decoder.version",
r"model.encoder.embed_positions.weights",
r"model.decoder.embed_positions.weights",
]
diff --git a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
index c78d19056bd3..dd575575de6d 100755
--- a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
@@ -90,7 +90,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
"""
Make causal mask used for bi-directional self-attention.
"""
- bsz, tgt_len = input_ids_shape
+ bsz = input_ids_shape[0]
+ tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
@@ -103,7 +104,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
-def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -142,7 +143,8 @@ def __init__(self, config: Speech2TextConfig, **kwargs):
]
def call(self, input_features: tf.Tensor) -> tf.Tensor:
- hidden_states = tf.identity(input_features) # TF Conv1D assumes Batch x Time x Channels, same as the input
+ # TF Conv1D assumes Batch x Time x Channels, same as the input
+ hidden_states = tf.cast(input_features, tf.float32)
for i, conv in enumerate(self.conv_layers):
# equivalent to `padding=k // 2` on PT's `nn.Conv1d`
pad_len = self.kernel_sizes[i] // 2
@@ -186,23 +188,20 @@ def _get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optiona
# zero pad
emb = tf.concat([emb, tf.zeros(num_embeddings, 1)], axis=1)
if padding_idx is not None:
- emb = tf.concat([emb[:padding_idx, :], tf.zeros((1, emb.shape[1])), emb[padding_idx + 1 :, :]], axis=0)
+ emb = tf.concat([emb[:padding_idx, :], tf.zeros((1, tf.shape(emb)[1])), emb[padding_idx + 1 :, :]], axis=0)
return emb
- def _resize_embeddings(self):
- """Recreates (and effectivelly resizes) the sinusoidal embeddings"""
- self.embeddings = self.add_weight(
- name="weights", # name also used in PT
- shape=self.embedding_weights.shape,
- )
- self.embeddings.assign(self.embedding_weights)
-
def build(self, input_shape: tf.TensorShape):
"""
Build shared token embedding layer Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
- self._resize_embeddings()
+ self.embeddings = self.add_weight(
+ name="weights", # name also used in PT
+ shape=tf.shape(self.embedding_weights),
+ trainable=False,
+ )
+ self.embeddings.assign(self.embedding_weights)
super().build(input_shape)
def call(self, input_ids: tf.Tensor, past_key_values_length: int = 0) -> tf.Tensor:
@@ -214,7 +213,7 @@ def call(self, input_ids: tf.Tensor, past_key_values_length: int = 0) -> tf.Tens
max_pos = self.padding_idx + 1 + seq_len
if max_pos > shape_list(self.embeddings)[0]:
self.embedding_weights = self._get_embedding(max_pos + self.offset, self.embedding_dim, self.padding_idx)
- self._resize_embeddings()
+ self.embeddings.assign(self.embedding_weights)
return tf.reshape(tf.gather(self.embeddings, tf.reshape(position_ids, (-1,)), axis=0), (bsz, seq_len, -1))
@staticmethod
@@ -331,7 +330,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
- message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
)
if attention_mask is not None:
@@ -341,7 +343,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
- message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
@@ -357,7 +362,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
@@ -374,7 +382,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
- message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
)
attn_output = tf.transpose(
@@ -595,7 +606,7 @@ def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor):
@tf.function(
input_signature=[
{
- "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
+ "input_features": tf.TensorSpec((None, None, None), tf.float32, name="input_features"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
@@ -778,7 +789,6 @@ def _get_feature_vector_attention_mask(self, feature_vector_length, attention_ma
),
axis=-1,
)
-
attention_mask = tf.scatter_nd(indices=indices, updates=tf.ones(bsz), shape=[bsz, feature_vector_length])
attention_mask = tf.cast(tf.reverse(tf.math.cumsum(tf.reverse(attention_mask, [-1]), -1), [-1]), tf.int64)
return attention_mask
@@ -832,10 +842,10 @@ def call(
# subsample attention mask if necessary
if attention_mask is not None:
- attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[1], attention_mask)
+ attention_mask = self._get_feature_vector_attention_mask(tf.shape(inputs_embeds)[1], attention_mask)
padding_mask = tf.cast(tf.math.not_equal(attention_mask, 1), tf.int64)
else:
- padding_mask = tf.zeros(inputs_embeds.shape[:-1], dtype=tf.int64)
+ padding_mask = tf.zeros(tf.shape(inputs_embeds)[:-1], dtype=tf.int64)
embed_pos = self.embed_positions(padding_mask)
@@ -856,7 +866,10 @@ def call(
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
- message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
+ message=(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(head_mask)[0]}."
+ ),
)
for idx, encoder_layer in enumerate(self.layers):
@@ -926,22 +939,6 @@ def get_embed_tokens(self):
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
- # create causal mask
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- combined_attention_mask = None
- if input_shape[-1] > 1:
- combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
-
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1])
- combined_attention_mask = (
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
- )
-
- return combined_attention_mask
-
@unpack_inputs
def call(
self,
@@ -1005,11 +1002,11 @@ def call(
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
- all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`tf.Tensor` of
- shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
- `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
- control over how to convert `input_ids` indices into associated vectors than the model's internal
- embedding lookup matrix.
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`tf.Tensor` of shape
+ `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids`
+ you can choose to directly pass an embedded representation. This is useful if you want more control
+ over how to convert `input_ids` indices into associated vectors than the model's internal embedding
+ lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@@ -1037,9 +1034,16 @@ def call(
else:
inputs_embeds = inputs_embeds
- attention_mask = self._prepare_decoder_attention_mask(
- attention_mask, input_shape, inputs_embeds, past_key_values_length
- )
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+ else:
+ combined_attention_mask = _expand_mask(
+ tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
+ )
+
+ if attention_mask is not None:
+ combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
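
The decoder above now builds the combined mask inline: a causal mask when the target length is greater than one, an expanded all-ones mask otherwise, plus the expanded padding mask when one is given. A standalone sketch of that combination with simplified local versions of the two helpers (the real `_make_causal_mask`/`_expand_mask` live in this file and differ in detail):

import tensorflow as tf

LARGE_NEGATIVE = -1e8

def make_causal_mask(bsz: int, tgt_len: int, past_len: int = 0) -> tf.Tensor:
    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
    # keep only the strictly upper triangle (future positions) as LARGE_NEGATIVE
    mask = tf.linalg.band_part(mask, 0, -1) - tf.linalg.band_part(mask, 0, 0)
    if past_len > 0:
        mask = tf.concat([tf.zeros((tgt_len, past_len)), mask], axis=-1)
    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))

def expand_mask(mask: tf.Tensor, tgt_len: int) -> tf.Tensor:
    # [bsz, src_len] -> [bsz, 1, tgt_len, src_len]; 0 where attended, LARGE_NEGATIVE where padded
    expanded = tf.tile(tf.cast(mask, tf.float32)[:, None, None, :], (1, 1, tgt_len, 1))
    return (1.0 - expanded) * LARGE_NEGATIVE

attention_mask = tf.constant([[1, 1, 1, 0]], dtype=tf.int32)  # last position is padding
combined = make_causal_mask(bsz=1, tgt_len=4) + expand_mask(attention_mask, tgt_len=4)
print(combined.shape)  # (1, 1, 4, 4)
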
@@ -1065,7 +1069,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
- message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+ message=(
+ f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(attn_mask)[0]}."
+ ),
)
for idx, decoder_layer in enumerate(self.layers):
@@ -1081,7 +1088,7 @@ def call(
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states,
- attention_mask=attention_mask,
+ attention_mask=combined_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=head_mask[idx] if head_mask is not None else None,
@@ -1184,7 +1191,7 @@ def call(
# downsample encoder attention mask
if attention_mask is not None:
encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
- encoder_outputs[0].shape[1], attention_mask
+ tf.shape(encoder_outputs[0])[1], attention_mask
)
else:
encoder_attention_mask = None
@@ -1313,6 +1320,8 @@ def __init__(self, config: Speech2TextConfig):
super().__init__(config)
self.model = TFSpeech2TextMainLayer(config, name="model")
self.lm_head = tf.keras.layers.Dense(self.config.vocab_size, use_bias=False, name="lm_head")
+ # TODO (Joao): investigate why Speech2Text has numerical issues in XLA generate
+ self.supports_xla_generation = False
def get_encoder(self):
return self.model.encoder
@@ -1444,8 +1453,8 @@ def serving_output(self, output):
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
- return TFSeq2SeqModelOutput(
- last_hidden_state=output.last_hidden_state,
+ return TFSeq2SeqLMOutput(
+ logits=output.logits,
past_key_values=pkv,
decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns,
diff --git a/src/transformers/models/speech_to_text/processing_speech_to_text.py b/src/transformers/models/speech_to_text/processing_speech_to_text.py
index 969df9d108fe..3f047932030f 100644
--- a/src/transformers/models/speech_to_text/processing_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/processing_speech_to_text.py
@@ -15,6 +15,7 @@
"""
Speech processor class for Speech2Text
"""
+import warnings
from contextlib import contextmanager
from ...processing_utils import ProcessorMixin
@@ -41,6 +42,7 @@ class Speech2TextProcessor(ProcessorMixin):
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)
self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
def __call__(self, *args, **kwargs):
"""
@@ -50,7 +52,35 @@ def __call__(self, *args, **kwargs):
[`~Speech2TextTokenizer.__call__`]. Please refer to the docstring of the above two methods for more
information.
"""
- return self.current_processor(*args, **kwargs)
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor(*args, **kwargs)
+
+ if "raw_speech" in kwargs:
+ warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
+ audio = kwargs.pop("raw_speech")
+ else:
+ audio = kwargs.pop("audio", None)
+ text = kwargs.pop("text", None)
+ if len(args) > 0:
+ audio = args[0]
+ args = args[1:]
+
+ if audio is None and text is None:
+ raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+ if audio is not None:
+ inputs = self.feature_extractor(audio, *args, **kwargs)
+ if text is not None:
+ encodings = self.tokenizer(text, **kwargs)
+
+ if text is None:
+ return inputs
+ elif audio is None:
+ return encodings
+ else:
+ inputs["labels"] = encodings["input_ids"]
+ return inputs
def batch_decode(self, *args, **kwargs):
"""
@@ -72,6 +102,13 @@ def as_target_processor(self):
Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
Speech2Text.
"""
+ warnings.warn(
+ "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+ "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+ "your audio inputs, or in a separate call."
+ )
+ self._in_target_context_manager = True
self.current_processor = self.tokenizer
yield
self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
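
The rewritten `__call__` above lets a single call handle both the audio input and the text labels, which is what the deprecation warning for `as_target_processor` points users to. A usage sketch, assuming network access to the checkpoint shown (the same one referenced in the archive maps of this diff):

import numpy as np
from transformers import Speech2TextProcessor

processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
audio = (np.random.randn(16000) * 0.1).astype(np.float32)

# Audio and labels in one call: the tokenized text comes back under "labels".
batch = processor(audio=audio, sampling_rate=16000, text="hello world", return_tensors="pt")
print(sorted(batch.keys()))  # 'input_features', 'labels', plus 'attention_mask' when the extractor returns one

# The old pattern keeps working through _in_target_context_manager, but now warns:
# with processor.as_target_processor():
#     labels = processor("hello world", return_tensors="pt").input_ids
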
diff --git a/src/transformers/models/speech_to_text/tokenization_speech_to_text.py b/src/transformers/models/speech_to_text/tokenization_speech_to_text.py
index 7d77c945ced8..e1bc681499f7 100644
--- a/src/transformers/models/speech_to_text/tokenization_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/tokenization_speech_to_text.py
@@ -36,10 +36,14 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/s2t-small-librispeech-asr": "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/vocab.json",
+ "facebook/s2t-small-librispeech-asr": (
+ "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/vocab.json"
+ ),
},
"spm_file": {
- "facebook/s2t-small-librispeech-asr": "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/sentencepiece.bpe.model"
+ "facebook/s2t-small-librispeech-asr": (
+ "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/sentencepiece.bpe.model"
+ )
},
}
diff --git a/src/transformers/models/speech_to_text_2/__init__.py b/src/transformers/models/speech_to_text_2/__init__.py
index d4ea8d037a0d..645a39746093 100644
--- a/src/transformers/models/speech_to_text_2/__init__.py
+++ b/src/transformers/models/speech_to_text_2/__init__.py
@@ -17,20 +17,28 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_speech_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_sentencepiece_available,
+ is_speech_available,
+ is_torch_available,
+)
_import_structure = {
- "configuration_speech_to_text_2": [
- "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP",
- "Speech2Text2Config",
- ],
+ "configuration_speech_to_text_2": ["SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2Text2Config"],
"processing_speech_to_text_2": ["Speech2Text2Processor"],
"tokenization_speech_to_text_2": ["Speech2Text2Tokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_speech_to_text_2"] = [
"SPEECH_TO_TEXT_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Speech2Text2ForCausalLM",
@@ -43,7 +51,12 @@
from .processing_speech_to_text_2 import Speech2Text2Processor
from .tokenization_speech_to_text_2 import Speech2Text2Tokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_speech_to_text_2 import (
SPEECH_TO_TEXT_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Speech2Text2ForCausalLM,
diff --git a/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py
index d27bad73c73c..c1b3cf7e4c7f 100644
--- a/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py
+++ b/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py
@@ -21,7 +21,9 @@
logger = logging.get_logger(__name__)
SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "facebook/s2t-wav2vec2-large-en-de": "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/config.json",
+ "facebook/s2t-wav2vec2-large-en-de": (
+ "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/config.json"
+ ),
# See all Speech2Text models at https://huggingface.co/models?filter=speech2text2
}
diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
index dccbd2adf48b..9dc22e11a22e 100755
--- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
+++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
@@ -49,7 +49,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -71,7 +71,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->Speech2Text2
@@ -238,7 +238,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -254,7 +255,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -275,7 +277,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -492,7 +495,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -569,8 +572,8 @@ def forward(
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
- all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor`
- of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
+ shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
control over how to convert `input_ids` indices into associated vectors than the model's internal
embedding lookup matrix.
@@ -633,7 +636,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -649,7 +653,8 @@ def forward(
if use_cache:
logger.warning(
- "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
+ "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache ="
+ " False`..."
)
use_cache = False
@@ -735,7 +740,8 @@ def forward(self, *args, **kwargs):
@add_start_docstrings(
- "The Speech2Text2 Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and [`SpeechEncoderDecoder`].",
+ "The Speech2Text2 Decoder with a language modeling head. Can be used as the decoder part of"
+ " [`EncoderDecoderModel`] and [`SpeechEncoderDecoder`].",
SPEECH_TO_TEXT_2_START_DOCSTRING,
)
class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
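
The reflowed docstring above describes the intended pairing: Speech2Text2 is a decoder-only model meant to sit behind a speech encoder. A usage sketch with the combined checkpoint referenced elsewhere in this diff (requires network access; the one-second silent waveform is just a stand-in for real audio):

import numpy as np
from transformers import Speech2Text2Processor, SpeechEncoderDecoderModel

model = SpeechEncoderDecoderModel.from_pretrained("facebook/s2t-wav2vec2-large-en-de")
processor = Speech2Text2Processor.from_pretrained("facebook/s2t-wav2vec2-large-en-de")

waveform = np.zeros(16000, dtype=np.float32)
inputs = processor(audio=waveform, sampling_rate=16000, return_tensors="pt")
generated_ids = model.generate(inputs=inputs["input_values"])
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
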
diff --git a/src/transformers/models/speech_to_text_2/processing_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/processing_speech_to_text_2.py
index 28189ba88198..c40831d0214a 100644
--- a/src/transformers/models/speech_to_text_2/processing_speech_to_text_2.py
+++ b/src/transformers/models/speech_to_text_2/processing_speech_to_text_2.py
@@ -15,6 +15,7 @@
"""
Speech processor class for Speech2Text2
"""
+import warnings
from contextlib import contextmanager
from ...processing_utils import ProcessorMixin
@@ -40,6 +41,7 @@ class Speech2Text2Processor(ProcessorMixin):
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)
self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
def __call__(self, *args, **kwargs):
"""
@@ -49,7 +51,35 @@ def __call__(self, *args, **kwargs):
Speech2Text2Tokenizer's [`~Speech2Text2Tokenizer.__call__`]. Please refer to the docstring of the above two
methods for more information.
"""
- return self.current_processor(*args, **kwargs)
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor(*args, **kwargs)
+
+ if "raw_speech" in kwargs:
+ warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
+ audio = kwargs.pop("raw_speech")
+ else:
+ audio = kwargs.pop("audio", None)
+ text = kwargs.pop("text", None)
+ if len(args) > 0:
+ audio = args[0]
+ args = args[1:]
+
+ if audio is None and text is None:
+ raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+ if audio is not None:
+ inputs = self.feature_extractor(audio, *args, **kwargs)
+ if text is not None:
+ encodings = self.tokenizer(text, **kwargs)
+
+ if text is None:
+ return inputs
+ elif audio is None:
+ return encodings
+ else:
+ inputs["labels"] = encodings["input_ids"]
+ return inputs
def batch_decode(self, *args, **kwargs):
"""
@@ -71,6 +101,13 @@ def as_target_processor(self):
Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
Speech2Text2.
"""
+ warnings.warn(
+ "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+ "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+ "your audio inputs, or in a separate call."
+ )
+ self._in_target_context_manager = True
self.current_processor = self.tokenizer
yield
self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
diff --git a/src/transformers/models/speech_to_text_2/tokenization_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/tokenization_speech_to_text_2.py
index 51d5c31ec991..3365dfe382ae 100644
--- a/src/transformers/models/speech_to_text_2/tokenization_speech_to_text_2.py
+++ b/src/transformers/models/speech_to_text_2/tokenization_speech_to_text_2.py
@@ -33,13 +33,19 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/s2t-wav2vec2-large-en-de": "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/vocab.json",
+ "facebook/s2t-wav2vec2-large-en-de": (
+ "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/vocab.json"
+ ),
},
"tokenizer_config_file": {
- "facebook/s2t-wav2vec2-large-en-de": "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/tokenizer_config.json",
+ "facebook/s2t-wav2vec2-large-en-de": (
+ "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/tokenizer_config.json"
+ ),
},
"merges_file": {
- "facebook/s2t-wav2vec2-large-en-de": "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/merges.txt",
+ "facebook/s2t-wav2vec2-large-en-de": (
+ "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/merges.txt"
+ ),
},
}
@@ -244,7 +250,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
if self.bpe_ranks is None:
diff --git a/src/transformers/models/splinter/__init__.py b/src/transformers/models/splinter/__init__.py
index 6a2308bbf535..9f056d7200a1 100644
--- a/src/transformers/models/splinter/__init__.py
+++ b/src/transformers/models/splinter/__init__.py
@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
@@ -25,13 +25,24 @@
"tokenization_splinter": ["SplinterTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_splinter_fast"] = ["SplinterTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_splinter"] = [
"SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST",
"SplinterForQuestionAnswering",
+ "SplinterForPreTraining",
"SplinterLayer",
"SplinterModel",
"SplinterPreTrainedModel",
@@ -42,12 +53,23 @@
from .configuration_splinter import SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP, SplinterConfig
from .tokenization_splinter import SplinterTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_splinter_fast import SplinterTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_splinter import (
SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ SplinterForPreTraining,
SplinterForQuestionAnswering,
SplinterLayer,
SplinterModel,
diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py
index 840aa07b87ea..1f94f6f9ad27 100755
--- a/src/transformers/models/splinter/modeling_splinter.py
+++ b/src/transformers/models/splinter/modeling_splinter.py
@@ -16,6 +16,7 @@
import math
+from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
@@ -24,7 +25,7 @@
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
-from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, QuestionAnsweringModelOutput
+from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
@@ -370,7 +371,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -908,8 +910,8 @@ def forward(
start_logits, end_logits = start_logits.squeeze(1), end_logits.squeeze(1)
if attention_mask is not None:
- start_logits = start_logits + (1 - attention_mask) * -10000.0
- end_logits = end_logits + (1 - attention_mask) * -10000.0
+ start_logits = start_logits + (1 - attention_mask) * torch.finfo(start_logits.dtype).min
+ end_logits = end_logits + (1 - attention_mask) * torch.finfo(end_logits.dtype).min
total_loss = None
if start_positions is not None and end_positions is not None:
@@ -939,3 +941,171 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
+
+
+@dataclass
+class SplinterForPreTrainingOutput(ModelOutput):
+ """
+ Class for outputs of Splinter as a span selection model.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when start and end positions are provided):
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+ start_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
+ Span-start scores (before SoftMax).
+ end_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
+ Span-end scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ start_logits: torch.FloatTensor = None
+ end_logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@add_start_docstrings(
+ """
+ Splinter Model for the recurring span selection task as done during the pretraining. The difference to the QA task
+ is that we do not have a question, but multiple question tokens that replace the occurrences of recurring spans
+ instead.
+ """,
+ SPLINTER_START_DOCSTRING,
+)
+class SplinterForPreTraining(SplinterPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.splinter = SplinterModel(config)
+ self.splinter_qass = QuestionAwareSpanSelectionHead(config)
+ self.question_token_id = config.question_token_id
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(
+ SPLINTER_INPUTS_DOCSTRING.format("batch_size, num_questions, sequence_length")
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ question_positions: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, SplinterForPreTrainingOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
+ The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size,
+ num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be
+ the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size,
+ sequence_length)`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if question_positions is None and start_positions is not None and end_positions is not None:
+ raise TypeError("question_positions must be specified in order to calculate the loss")
+
+ elif question_positions is None and input_ids is None:
+ raise TypeError("question_positions must be specified when input_embeds is used")
+
+ elif question_positions is None:
+ question_positions = self._prepare_question_positions(input_ids)
+
+ outputs = self.splinter(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ batch_size, sequence_length, dim = sequence_output.size()
+ # [batch_size, num_questions, sequence_length]
+ start_logits, end_logits = self.splinter_qass(sequence_output, question_positions)
+
+ num_questions = question_positions.size(1)
+ if attention_mask is not None:
+ attention_mask_for_each_question = attention_mask.unsqueeze(1).expand(
+ batch_size, num_questions, sequence_length
+ )
+ start_logits = start_logits + (1 - attention_mask_for_each_question) * torch.finfo(start_logits.dtype).min
+ end_logits = end_logits + (1 - attention_mask_for_each_question) * torch.finfo(end_logits.dtype).min
+
+ total_loss = None
+ # [batch_size, num_questions, sequence_length]
+ if start_positions is not None and end_positions is not None:
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ start_positions.clamp_(0, max(0, sequence_length - 1))
+ end_positions.clamp_(0, max(0, sequence_length - 1))
+
+ # Ignore zero positions in the loss. Splinter never predicts zero
+ # during pretraining and zero is used for padding question
+ # tokens as well as for start and end positions of padded
+ # question tokens.
+ loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)
+ start_loss = loss_fct(
+ start_logits.view(batch_size * num_questions, sequence_length),
+ start_positions.view(batch_size * num_questions),
+ )
+ end_loss = loss_fct(
+ end_logits.view(batch_size * num_questions, sequence_length),
+ end_positions.view(batch_size * num_questions),
+ )
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return SplinterForPreTrainingOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def _prepare_question_positions(self, input_ids: torch.Tensor) -> torch.Tensor:
+ rows, flat_positions = torch.where(input_ids == self.config.question_token_id)
+ num_questions = torch.bincount(rows)
+ positions = torch.full(
+ (input_ids.size(0), num_questions.max()),
+ self.config.pad_token_id,
+ dtype=torch.long,
+ device=input_ids.device,
+ )
+ cols = torch.cat([torch.arange(n) for n in num_questions])
+ positions[rows, cols] = flat_positions
+ return positions
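
A minimal usage sketch for the `SplinterForPreTraining` class added above. The checkpoint name `tals/splinter-base`, the example text, and the way `question_positions` is built are illustrative assumptions; the call signature and the fallback to `_prepare_question_positions` follow the code in this diff.

    import torch
    from transformers import SplinterTokenizer
    from transformers.models.splinter.modeling_splinter import SplinterForPreTraining

    # assumption: any Splinter checkpoint whose tokenizer has a [QUESTION] token works here
    tokenizer = SplinterTokenizer.from_pretrained("tals/splinter-base")
    model = SplinterForPreTraining.from_pretrained("tals/splinter-base")

    # recurring-span pretraining: repeated spans are replaced by question tokens
    text = "[QUESTION] was born in 1856. [QUESTION] studied electrical engineering."
    inputs = tokenizer(text, return_tensors="pt")

    # positions of the question tokens; if omitted (and no labels are passed), the model
    # derives them itself via _prepare_question_positions above
    question_positions = (
        (inputs["input_ids"] == model.config.question_token_id).nonzero()[:, 1].unsqueeze(0)
    )

    outputs = model(**inputs, question_positions=question_positions)
    # start_logits / end_logits have shape (batch_size, num_questions, sequence_length)
    print(outputs.start_logits.shape, outputs.end_logits.shape)
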
diff --git a/src/transformers/models/splinter/tokenization_splinter.py b/src/transformers/models/splinter/tokenization_splinter.py
index 9649da03f9f1..f600566e6e94 100644
--- a/src/transformers/models/splinter/tokenization_splinter.py
+++ b/src/transformers/models/splinter/tokenization_splinter.py
@@ -153,8 +153,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
diff --git a/src/transformers/models/squeezebert/__init__.py b/src/transformers/models/squeezebert/__init__.py
index 433b9f93343f..9f758bebe024 100644
--- a/src/transformers/models/squeezebert/__init__.py
+++ b/src/transformers/models/squeezebert/__init__.py
@@ -18,18 +18,32 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
- "configuration_squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig"],
+ "configuration_squeezebert": [
+ "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "SqueezeBertConfig",
+ "SqueezeBertOnnxConfig",
+ ],
"tokenization_squeezebert": ["SqueezeBertTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_squeezebert_fast"] = ["SqueezeBertTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_squeezebert"] = [
"SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"SqueezeBertForMaskedLM",
@@ -44,13 +58,27 @@
if TYPE_CHECKING:
- from .configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
+ from .configuration_squeezebert import (
+ SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ SqueezeBertConfig,
+ SqueezeBertOnnxConfig,
+ )
from .tokenization_squeezebert import SqueezeBertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_squeezebert_fast import SqueezeBertTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_squeezebert import (
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
SqueezeBertForMaskedLM,
diff --git a/src/transformers/models/squeezebert/configuration_squeezebert.py b/src/transformers/models/squeezebert/configuration_squeezebert.py
index 5a77495fc704..41b47ff5750e 100644
--- a/src/transformers/models/squeezebert/configuration_squeezebert.py
+++ b/src/transformers/models/squeezebert/configuration_squeezebert.py
@@ -13,17 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" SqueezeBERT model configuration"""
+from collections import OrderedDict
+from typing import Mapping
from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
from ...utils import logging
logger = logging.get_logger(__name__)
SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "squeezebert/squeezebert-uncased": "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/config.json",
+ "squeezebert/squeezebert-uncased": (
+ "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/config.json"
+ ),
"squeezebert/squeezebert-mnli": "https://huggingface.co/squeezebert/squeezebert-mnli/resolve/main/config.json",
- "squeezebert/squeezebert-mnli-headless": "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/config.json",
+ "squeezebert/squeezebert-mnli-headless": (
+ "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/config.json"
+ ),
}
@@ -150,3 +157,20 @@ def __init__(
self.post_attention_groups = post_attention_groups
self.intermediate_groups = intermediate_groups
self.output_groups = output_groups
+
+
+# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Bert->SqueezeBert
+class SqueezeBertOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "multiple-choice":
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+ else:
+ dynamic_axis = {0: "batch", 1: "sequence"}
+ return OrderedDict(
+ [
+ ("input_ids", dynamic_axis),
+ ("attention_mask", dynamic_axis),
+ ("token_type_ids", dynamic_axis),
+ ]
+ )
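
A short sketch of what the new `SqueezeBertOnnxConfig` exposes, assuming the base-class constructor `OnnxConfig.from_model_config` used by the other ONNX configs in the library; the printed mapping mirrors the `inputs` property defined above.

    from transformers import SqueezeBertConfig
    from transformers.models.squeezebert import SqueezeBertOnnxConfig

    onnx_config = SqueezeBertOnnxConfig.from_model_config(SqueezeBertConfig(), task="default")

    # dynamic axes that the ONNX export will treat as symbolic dimensions
    print(onnx_config.inputs)
    # OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
    #              ('attention_mask', {0: 'batch', 1: 'sequence'}),
    #              ('token_type_ids', {0: 'batch', 1: 'sequence'})])
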
diff --git a/src/transformers/models/squeezebert/tokenization_squeezebert.py b/src/transformers/models/squeezebert/tokenization_squeezebert.py
index e41e576455fe..72d927eccafb 100644
--- a/src/transformers/models/squeezebert/tokenization_squeezebert.py
+++ b/src/transformers/models/squeezebert/tokenization_squeezebert.py
@@ -24,9 +24,13 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "squeezebert/squeezebert-uncased": "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/vocab.txt",
+ "squeezebert/squeezebert-uncased": (
+ "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/vocab.txt"
+ ),
"squeezebert/squeezebert-mnli": "https://huggingface.co/squeezebert/squeezebert-mnli/resolve/main/vocab.txt",
- "squeezebert/squeezebert-mnli-headless": "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/vocab.txt",
+ "squeezebert/squeezebert-mnli-headless": (
+ "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/vocab.txt"
+ ),
}
}
diff --git a/src/transformers/models/squeezebert/tokenization_squeezebert_fast.py b/src/transformers/models/squeezebert/tokenization_squeezebert_fast.py
index 58708030f9f3..5ee656e5a8d5 100644
--- a/src/transformers/models/squeezebert/tokenization_squeezebert_fast.py
+++ b/src/transformers/models/squeezebert/tokenization_squeezebert_fast.py
@@ -25,14 +25,24 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "squeezebert/squeezebert-uncased": "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/vocab.txt",
+ "squeezebert/squeezebert-uncased": (
+ "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/vocab.txt"
+ ),
"squeezebert/squeezebert-mnli": "https://huggingface.co/squeezebert/squeezebert-mnli/resolve/main/vocab.txt",
- "squeezebert/squeezebert-mnli-headless": "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/vocab.txt",
+ "squeezebert/squeezebert-mnli-headless": (
+ "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "squeezebert/squeezebert-uncased": "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/tokenizer.json",
- "squeezebert/squeezebert-mnli": "https://huggingface.co/squeezebert/squeezebert-mnli/resolve/main/tokenizer.json",
- "squeezebert/squeezebert-mnli-headless": "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/tokenizer.json",
+ "squeezebert/squeezebert-uncased": (
+ "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/tokenizer.json"
+ ),
+ "squeezebert/squeezebert-mnli": (
+ "https://huggingface.co/squeezebert/squeezebert-mnli/resolve/main/tokenizer.json"
+ ),
+ "squeezebert/squeezebert-mnli-headless": (
+ "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/swin/__init__.py b/src/transformers/models/swin/__init__.py
index b8cb65d08b3a..33a9bddeea73 100644
--- a/src/transformers/models/swin/__init__.py
+++ b/src/transformers/models/swin/__init__.py
@@ -18,15 +18,18 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
-_import_structure = {
- "configuration_swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig"],
-}
+_import_structure = {"configuration_swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_swin"] = [
"SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
"SwinForImageClassification",
@@ -35,11 +38,29 @@
"SwinPreTrainedModel",
]
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_swin"] = [
+ "TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TFSwinForImageClassification",
+ "TFSwinForMaskedImageModeling",
+ "TFSwinModel",
+ "TFSwinPreTrainedModel",
+ ]
if TYPE_CHECKING:
from .configuration_swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_swin import (
SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
SwinForImageClassification,
@@ -48,6 +69,19 @@
SwinPreTrainedModel,
)
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_swin import (
+ TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TFSwinForImageClassification,
+ TFSwinForMaskedImageModeling,
+ TFSwinModel,
+ TFSwinPreTrainedModel,
+ )
else:
import sys
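
Below is a minimal sketch of the guarded-import pattern introduced above, seen from the user side; `TFSwinModel` is only used as an illustrative class, and the guard itself is lifted directly from this diff.

    from transformers.utils import OptionalDependencyNotAvailable, is_tf_available

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # without TensorFlow, transformers exposes dummy objects that raise an
        # informative ImportError only when they are actually instantiated
        TFSwinModel = None
    else:
        from transformers import TFSwinModel
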
diff --git a/src/transformers/models/swin/configuration_swin.py b/src/transformers/models/swin/configuration_swin.py
index 9956482b9ab7..878a73e9208b 100644
--- a/src/transformers/models/swin/configuration_swin.py
+++ b/src/transformers/models/swin/configuration_swin.py
@@ -21,7 +21,9 @@
logger = logging.get_logger(__name__)
SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "microsoft/swin-tiny-patch4-window7-224": "https://huggingface.co/microsoft/swin-tiny-patch4-window7-224/resolve/main/config.json",
+ "microsoft/swin-tiny-patch4-window7-224": (
+ "https://huggingface.co/microsoft/swin-tiny-patch4-window7-224/resolve/main/config.json"
+ ),
# See all Swin models at https://huggingface.co/models?filter=swin
}
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index 727399f17f4d..48c9b8cccf9e 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -59,7 +59,7 @@
# See all Swin models at https://huggingface.co/models?filter=swin
]
-# to_2tuple, drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library.
+# drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library.
@dataclass
@@ -103,7 +103,7 @@ class SwinModelOutput(ModelOutput):
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
- pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
Average pooling of the last layer hidden-state.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
@@ -125,7 +125,7 @@ class SwinModelOutput(ModelOutput):
"""
last_hidden_state: torch.FloatTensor = None
- pooler_output: torch.FloatTensor = None
+ pooler_output: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@@ -203,13 +203,6 @@ class SwinImageClassifierOutput(ModelOutput):
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-# Copied from transformers.models.vit.modeling_vit.to_2tuple
-def to_2tuple(x):
- if isinstance(x, collections.abc.Iterable):
- return x
- return (x, x)
-
-
def window_partition(input_feature, window_size):
"""
Partitions the given input into windows.
@@ -226,26 +219,12 @@ def window_reverse(windows, window_size, height, width):
"""
Merges windows to produce higher resolution features.
"""
- batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
+ batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
return windows
-def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
- """
- Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- """
- if drop_prob == 0.0 or not training:
- return input
- keep_prob = 1 - drop_prob
- shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = input.new_empty(shape).bernoulli_(keep_prob)
- if keep_prob > 0.0 and scale_by_keep:
- random_tensor.div_(keep_prob)
- return input * random_tensor
-
-
class SwinEmbeddings(nn.Module):
"""
Construct the patch and position embeddings. Optionally, also the mask token.
@@ -254,12 +233,7 @@ class SwinEmbeddings(nn.Module):
def __init__(self, config, use_mask_token=False):
super().__init__()
- self.patch_embeddings = SwinPatchEmbeddings(
- image_size=config.image_size,
- patch_size=config.patch_size,
- num_channels=config.num_channels,
- embed_dim=config.embed_dim,
- )
+ self.patch_embeddings = SwinPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.patch_grid = self.patch_embeddings.grid_size
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
@@ -295,20 +269,25 @@ def forward(
class SwinPatchEmbeddings(nn.Module):
"""
- Image to Patch Embedding.
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
"""
- def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
+ def __init__(self, config):
super().__init__()
- image_size = to_2tuple(image_size)
- patch_size = to_2tuple(patch_size)
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.embed_dim
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
+ self.num_channels = num_channels
self.num_patches = num_patches
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
- self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def maybe_pad(self, pixel_values, height, width):
if width % self.patch_size[1] != 0:
@@ -320,7 +299,11 @@ def maybe_pad(self, pixel_values, height, width):
return pixel_values
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
- _, _, height, width = pixel_values.shape
+ _, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
+ )
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
embeddings = self.projection(pixel_values)
@@ -385,16 +368,40 @@ def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]
return input_feature
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Swin
class SwinDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
- def __init__(self, drop_prob=None, scale_by_keep=True):
- super(SwinDropPath, self).__init__()
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
self.drop_prob = drop_prob
- self.scale_by_keep = scale_by_keep
- def forward(self, input):
- return drop_path(input, self.drop_prob, self.training, self.scale_by_keep)
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
class SwinSelfAttention(nn.Module):
@@ -402,13 +409,16 @@ def __init__(self, config, dim, num_heads):
super().__init__()
if dim % num_heads != 0:
raise ValueError(
- f"The hidden size ({dim}) is not a multiple of the number of attention " f"heads ({num_heads})"
+ f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
)
self.num_attention_heads = num_heads
self.attention_head_size = int(dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
- self.window_size = to_2tuple(config.window_size)
+ window_size = config.window_size
+ self.window_size = (
+ window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
+ )
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
@@ -435,7 +445,7 @@ def __init__(self, config, dim, num_heads):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
@@ -488,7 +498,7 @@ def forward(
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -997,7 +1007,8 @@ def forward(
@add_start_docstrings(
- "Swin Model with a decoder on top for masked image modeling, as proposed in `SimMIM `__.",
+ "Swin Model with a decoder on top for masked image modeling, as proposed in"
+ " [SimMIM](https://arxiv.org/abs/2111.09886).",
SWIN_START_DOCSTRING,
)
class SwinForMaskedImageModeling(SwinPreTrainedModel):
@@ -1008,7 +1019,9 @@ def __init__(self, config):
num_features = int(config.embed_dim * 2 ** (config.num_layers - 1))
self.decoder = nn.Sequential(
- nn.Conv2d(in_channels=num_features, out_channels=config.encoder_stride**2 * 3, kernel_size=1),
+ nn.Conv2d(
+ in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
+ ),
nn.PixelShuffle(config.encoder_stride),
)
@@ -1067,11 +1080,10 @@ def forward(
)
sequence_output = outputs[0]
-
# Reshape to (batch_size, num_channels, height, width)
sequence_output = sequence_output.transpose(1, 2)
batch_size, num_channels, sequence_length = sequence_output.shape
- height = width = int(sequence_length**0.5)
+ height = width = math.floor(sequence_length**0.5)
sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)
# Reconstruct pixel values
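
A small round-trip check of `window_partition`/`window_reverse` under toy dimensions (the sizes are assumptions); it exercises the floor-based batch-size recovery changed above.

    import torch
    from transformers.models.swin.modeling_swin import window_partition, window_reverse

    batch_size, height, width, channels, window_size = 2, 8, 8, 4, 4
    features = torch.randn(batch_size, height, width, channels)

    # (batch_size * num_windows, window_size, window_size, channels)
    windows = window_partition(features, window_size)
    assert windows.shape == (batch_size * (height // window_size) * (width // window_size), 4, 4, channels)

    # merging the windows back recovers the original feature map exactly
    restored = window_reverse(windows, window_size, height, width)
    assert torch.equal(restored, features)
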
diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py
new file mode 100644
index 000000000000..2f9bd27b0e00
--- /dev/null
+++ b/src/transformers/models/swin/modeling_tf_swin.py
@@ -0,0 +1,1490 @@
+# coding=utf-8
+# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TF 2.0 Swin Transformer model."""
+
+
+import collections.abc
+import math
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+
+import tensorflow as tf
+
+from ...activations_tf import ACT2FN
+from ...modeling_tf_utils import (
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ get_initializer,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_swin import SwinConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "SwinConfig"
+_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "microsoft/swin-tiny-patch4-window7-224"
+_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "microsoft/swin-tiny-patch4-window7-224"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
+TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "microsoft/swin-tiny-patch4-window7-224",
+ # See all Swin models at https://huggingface.co/models?filter=swin
+]
+
+# drop_path, TFSwinPatchEmbeddings, TFSwinPatchMerging and TFSwinDropPath are TensorFlow
+# implementations of functionality from the (PyTorch) timm library.
+
+
+@dataclass
+class TFSwinEncoderOutput(ModelOutput):
+ """
+ Swin encoder's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: tf.Tensor = None
+ hidden_states: Optional[Tuple[tf.Tensor]] = None
+ attentions: Optional[Tuple[tf.Tensor]] = None
+ reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None
+
+
+@dataclass
+class TFSwinModelOutput(ModelOutput):
+ """
+ Swin model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
+ Average pooling of the last layer hidden-state.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: tf.Tensor = None
+ pooler_output: Optional[tf.Tensor] = None
+ hidden_states: Optional[Tuple[tf.Tensor]] = None
+ attentions: Optional[Tuple[tf.Tensor]] = None
+ reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None
+
+
+@dataclass
+class TFSwinMaskedImageModelingOutput(ModelOutput):
+ """
+ Swin masked image model outputs.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
+ Masked image modeling (MLM) loss.
+ logits (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ Reconstructed pixel values.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ loss: Optional[tf.Tensor] = None
+ logits: tf.Tensor = None
+ hidden_states: Optional[Tuple[tf.Tensor]] = None
+ attentions: Optional[Tuple[tf.Tensor]] = None
+ reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None
+
+
+@dataclass
+class TFSwinImageClassifierOutput(ModelOutput):
+ """
+ Swin outputs for image classification.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ loss: Optional[tf.Tensor] = None
+ logits: tf.Tensor = None
+ hidden_states: Optional[Tuple[tf.Tensor]] = None
+ attentions: Optional[Tuple[tf.Tensor]] = None
+ reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None
+
+
+def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor:
+ """
+ Partitions the given input into windows.
+ """
+ batch_size, height, width, num_channels = shape_list(input_feature)
+ input_feature = tf.reshape(
+ input_feature,
+ (batch_size, height // window_size, window_size, width // window_size, window_size, num_channels),
+ )
+ windows = tf.transpose(input_feature, (0, 1, 3, 2, 4, 5))
+ windows = tf.reshape(windows, (-1, window_size, window_size, num_channels))
+ return windows
+
+
+def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int) -> tf.Tensor:
+ """
+ Merges windows to produce higher resolution features.
+ """
+ x = tf.shape(windows)[0]
+ y = tf.cast(height * width / (window_size * window_size), tf.int32)
+ batch_size = tf.math.floordiv(x, y)
+ windows = tf.reshape(
+ windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
+ )
+ windows = tf.transpose(windows, (0, 1, 3, 2, 4, 5))
+ windows = tf.reshape(windows, (batch_size, height, width, -1))
+ return windows
+
+
+def drop_path(
+ input: tf.Tensor, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+) -> tf.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ input_shape = shape_list(input)
+ ndim = len(input_shape)
+ shape = [input_shape[0]] + [1] * (ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = tf.random.uniform(shape)
+ random_tensor = tf.where(random_tensor <= keep_prob, 1.0, 0.0)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor /= keep_prob
+ return input * random_tensor
+
+
+class TFSwinEmbeddings(tf.keras.layers.Layer):
+ """
+ Construct the patch and position embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config: SwinConfig, use_mask_token: bool = False, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.patch_embeddings = TFSwinPatchEmbeddings(config, name="patch_embeddings")
+ self.num_patches = self.patch_embeddings.num_patches
+ self.patch_grid = self.patch_embeddings.grid_size
+ self.embed_dim = config.embed_dim
+ self.use_mask_token = use_mask_token
+ self.use_absolute_embeddings = config.use_absolute_embeddings
+
+ self.norm = tf.keras.layers.LayerNormalization(name="norm", epsilon=1e-5)
+ self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob, name="dropout")
+
+ def build(self, input_shape: tf.TensorShape) -> None:
+ if self.use_mask_token:
+ self.mask_token = self.add_weight(shape=(1, 1, self.embed_dim), initializer="zeros", name="mask_token")
+ else:
+ self.mask_token = None
+
+ if self.use_absolute_embeddings:
+ self.position_embeddings = self.add_weight(
+ (1, self.num_patches + 1, self.embed_dim), initializer="zeros", name="positional_embeddings"
+ )
+ else:
+ self.position_embeddings = None
+ super().build(input_shape)
+
+ def call(
+ self, pixel_values: tf.Tensor, bool_masked_pos: bool = None, training: bool = False
+ ) -> Tuple[tf.Tensor, Tuple[int, int]]:
+ embeddings, output_dimensions = self.patch_embeddings(pixel_values, training=training)
+ embeddings = self.norm(embeddings, training=training)
+ batch_size, seq_len, _ = shape_list(embeddings)
+
+ if bool_masked_pos is not None:
+ mask_tokens = tf.repeat(self.mask_token, batch_size, 0)
+ mask_tokens = tf.repeat(mask_tokens, seq_len, 1)
+ # replace the masked visual tokens by mask_tokens
+ mask = tf.expand_dims(bool_masked_pos, -1)
+ mask = tf.cast(mask, mask_tokens.dtype)
+
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ if self.position_embeddings is not None:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings, training=training)
+
+ return embeddings, output_dimensions
+
+
+class TFSwinPatchEmbeddings(tf.keras.layers.Layer):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.embed_dim
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+ self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+
+ self.projection = tf.keras.layers.Conv2D(
+ filters=hidden_size,
+ kernel_size=self.patch_size,
+ strides=self.patch_size,
+ padding="valid",
+ name="projection",
+ )
+
+ def maybe_pad(self, pixel_values: tf.Tensor, height: int, width: int) -> tf.Tensor:
+ if width % self.patch_size[1] != 0:
+ pad_values = ((0, 0), (0, 0), (0, 0), (0, self.patch_size[1] - width % self.patch_size[1]))
+ pixel_values = tf.pad(pixel_values, pad_values)
+ if height % self.patch_size[0] != 0:
+ pad_values = ((0, 0), (0, 0), (0, self.patch_size[0] - height % self.patch_size[0]), (0, 0))
+ pixel_values = tf.pad(pixel_values, pad_values)
+ return pixel_values
+
+ def call(self, pixel_values: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor, Tuple[int, int]]:
+ _, num_channels, height, width = shape_list(pixel_values)
+ if tf.executing_eagerly() and num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
+ )
+ # pad the input to be divisible by self.patch_size, if needed
+ pixel_values = self.maybe_pad(pixel_values, height, width)
+
+ # B,C,H,W -> B,H,W,C
+ pixel_values = tf.transpose(pixel_values, (0, 2, 3, 1))
+
+ embeddings = self.projection(pixel_values, training=training)
+
+ # B,H,W,C -> B,C,H,W
+ embeddings = tf.transpose(embeddings, (0, 3, 1, 2))
+
+ batch_size, channels, height, width = shape_list(embeddings)
+ output_dimensions = (height, width)
+
+ embeddings = tf.reshape(embeddings, (batch_size, channels, -1))
+ embeddings = tf.transpose(embeddings, (0, 2, 1))
+ return embeddings, output_dimensions
+
+
+class TFSwinPatchMerging(tf.keras.layers.Layer):
+ """
+ Patch Merging Layer.
+
+ Args:
+ input_resolution (`Tuple[int]`):
+ Resolution of input feature.
+ dim (`int`):
+ Number of input channels.
+ norm_layer (`tf.keras.layer.Layer`, *optional*, defaults to `tf.keras.layers.LayerNormalization`):
+ Normalization layer class.
+ """
+
+ def __init__(
+ self, input_resolution: Tuple[int, int], dim: int, norm_layer: Optional[Callable] = None, **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = tf.keras.layers.Dense(2 * dim, use_bias=False, name="reduction")
+ if norm_layer is None:
+ # Use same default epsilon as PyTorch
+ self.norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="norm")
+ else:
+ self.norm = norm_layer(name="norm")
+
+ def maybe_pad(self, input_feature: tf.Tensor, height: int, width: int) -> tf.Tensor:
+ should_pad = (height % 2 == 1) or (width % 2 == 1)
+ if should_pad:
+ pad_values = ((0, 0), (0, height % 2), (0, width % 2), (0, 0))
+ input_feature = tf.pad(input_feature, pad_values)
+
+ return input_feature
+
+ def call(self, input_feature: tf.Tensor, input_dimensions: Tuple[int, int], training: bool = False) -> tf.Tensor:
+ height, width = input_dimensions
+ # `dim` is height * width
+ batch_size, _, num_channels = shape_list(input_feature)
+
+ input_feature = tf.reshape(input_feature, (batch_size, height, width, num_channels))
+ # pad input to be divisible by 2 in height and width, if needed
+ input_feature = self.maybe_pad(input_feature, height, width)
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_0 = input_feature[:, 0::2, 0::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_1 = input_feature[:, 1::2, 0::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_2 = input_feature[:, 0::2, 1::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_3 = input_feature[:, 1::2, 1::2, :]
+ # batch_size height/2 width/2 4*num_channels
+ input_feature = tf.concat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
+ input_feature = tf.reshape(
+ input_feature, (batch_size, -1, 4 * num_channels)
+ ) # batch_size height/2*width/2 4*C
+
+ input_feature = self.norm(input_feature, training=training)
+ input_feature = self.reduction(input_feature, training=training)
+
+ return input_feature
+
+
+class TFSwinDropPath(tf.keras.layers.Layer):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: float = None, scale_by_keep: bool = True, **kwargs) -> None:
+ super(TFSwinDropPath, self).__init__(**kwargs)
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def call(self, input: tf.Tensor, training: bool = False) -> tf.Tensor:
+ return drop_path(input, self.drop_prob, training, self.scale_by_keep)
+
+
+class TFSwinSelfAttention(tf.keras.layers.Layer):
+ def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ if dim % num_heads != 0:
+ raise ValueError(
+ f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
+ )
+
+ self.num_attention_heads = num_heads
+ self.attention_head_size = int(dim / num_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ window_size = config.window_size
+ self.window_size = (
+ window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
+ )
+
+ self.query = tf.keras.layers.Dense(
+ self.all_head_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ use_bias=config.qkv_bias,
+ name="query",
+ )
+ self.key = tf.keras.layers.Dense(
+ self.all_head_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ use_bias=config.qkv_bias,
+ name="key",
+ )
+ self.value = tf.keras.layers.Dense(
+ self.all_head_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ use_bias=config.qkv_bias,
+ name="value",
+ )
+
+ self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
+
+ def build(self, input_shape: tf.TensorShape) -> None:
+ self.relative_position_bias_table = self.add_weight(
+ shape=(((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1)), self.num_attention_heads),
+ initializer="zeros",
+ name="relative_position_bias_table",
+ )
+ self.relative_position_index = self.add_weight(
+ shape=(self.window_size[0] ** 2, self.window_size[1] ** 2),
+ trainable=False,
+ dtype=tf.int32,
+ name="relative_position_index",
+ )
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = tf.range(self.window_size[0])
+ coords_w = tf.range(self.window_size[1])
+ coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
+ coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1))
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ relative_coords = tf.transpose(relative_coords, (1, 2, 0))
+
+ stack_0, stack_1 = tf.unstack(relative_coords, axis=2)
+ stack_0 += self.window_size[0] - 1
+ stack_0 *= 2 * self.window_size[1] - 1
+ stack_1 += self.window_size[1] - 1
+ relative_coords = tf.stack([stack_0, stack_1], axis=2)
+
+ self.relative_position_index.assign(tf.cast(tf.reduce_sum(relative_coords, axis=-1), tf.int32))
+ super().build(input_shape)
+
+ def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
+ new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size]
+ x = tf.reshape(x, new_x_shape)
+ return tf.transpose(x, (0, 2, 1, 3))
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: Optional[tf.Tensor] = None,
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: bool = False,
+ training: bool = False,
+ ) -> Tuple[tf.Tensor, ...]:
+ batch_size, dim, _ = shape_list(hidden_states)
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, (0, 1, 3, 2)))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ relative_position_bias = tf.gather(
+ self.relative_position_bias_table, tf.reshape(self.relative_position_index, (-1,))
+ )
+ relative_position_bias = tf.reshape(
+ relative_position_bias,
+ (self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1),
+ )
+
+ relative_position_bias = tf.transpose(relative_position_bias, (2, 0, 1))
+ attention_scores = attention_scores + tf.expand_dims(relative_position_bias, 0)
+
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in SwinModel call() function)
+ mask_shape = shape_list(attention_mask)[0]
+ attention_scores = tf.reshape(
+ attention_scores, (batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim)
+ )
+ attention_mask = tf.expand_dims(attention_mask, 1)
+ attention_mask = tf.expand_dims(attention_mask, 0)
+ attention_scores = attention_scores + attention_mask
+ attention_scores = tf.reshape(attention_scores, (-1, self.num_attention_heads, dim, dim))
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = tf.nn.softmax(attention_scores, axis=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs, training=training)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = tf.matmul(attention_probs, value_layer)
+ context_layer = tf.transpose(context_layer, (0, 2, 1, 3))
+ new_context_layer_shape = shape_list(context_layer)[:-2] + [
+ self.all_head_size,
+ ]
+ context_layer = tf.reshape(context_layer, new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+class TFSwinSelfOutput(tf.keras.layers.Layer):
+ def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.dense = tf.keras.layers.Dense(dim, name="dense")
+ self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob, name="dropout")
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ return hidden_states
+
+
+class TFSwinAttention(tf.keras.layers.Layer):
+ def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.self = TFSwinSelfAttention(config, dim, num_heads, name="self")
+ self.self_output = TFSwinSelfOutput(config, dim, name="output")
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ """
+ Prunes heads of the model. See base class `PreTrainedModel`. heads: dict of {layer_num: list of heads to prune
+ in this layer}
+ """
+ raise NotImplementedError
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: Optional[tf.Tensor] = None,
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: bool = False,
+ training: bool = False,
+ ) -> tf.Tensor:
+ self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions, training=training)
+ attention_output = self.self_output(self_outputs[0], hidden_states, training=training)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class TFSwinIntermediate(tf.keras.layers.Layer):
+ def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.dense = tf.keras.layers.Dense(int(config.mlp_ratio * dim), name="dense")
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class TFSwinOutput(tf.keras.layers.Layer):
+ def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.dense = tf.keras.layers.Dense(dim, name="dense")
+ self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob, name="dropout")
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ return hidden_states
+
+
+class TFSwinLayer(tf.keras.layers.Layer):
+ def __init__(
+ self, config, dim, input_resolution: Tuple[int, int], num_heads: int, shift_size: int = 0, **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ min_res = tf.reduce_min(input_resolution)
+ self.window_size = min_res if min_res <= config.window_size else config.window_size
+ self.shift_size = 0 if min_res <= self.window_size else shift_size
+ self.input_resolution = input_resolution
+
+ self.layernorm_before = tf.keras.layers.LayerNormalization(
+ epsilon=config.layer_norm_eps, name="layernorm_before"
+ )
+ self.attention = TFSwinAttention(config, dim, num_heads, name="attention")
+ self.drop_path = (
+ TFSwinDropPath(config.drop_path_rate, name="drop_path")
+ if config.drop_path_rate > 0.0
+ else tf.keras.layers.Activation("linear", name="drop_path")
+ )
+ self.layernorm_after = tf.keras.layers.LayerNormalization(
+ epsilon=config.layer_norm_eps, name="layernorm_after"
+ )
+ self.intermediate = TFSwinIntermediate(config, dim, name="intermediate")
+ self.swin_output = TFSwinOutput(config, dim, name="output")
+
+ def get_attn_mask(self, height: int, width: int, window_size: int, shift_size: int) -> Optional[tf.Tensor]:
+ img_mask = tf.zeros((height, width))
+ height_slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, -1))
+ width_slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, -1))
+
+ # calculate attention mask for SW-MSA
+ if shift_size > 0:
+ count = 0
+ for height_slice in height_slices:
+ for width_slice in width_slices:
+ height_inds = tf.range(height_slice[0] % height, height_slice[1] % height + 1)
+ width_inds = tf.range(width_slice[0] % width, width_slice[1] % width + 1)
+ indices = tf.reshape(tf.stack(tf.meshgrid(height_inds, width_inds), axis=-1), (-1, 2))
+ if len(indices) >= 1:
+ updates = tf.ones((len(indices),), dtype=img_mask.dtype) * count
+ img_mask = tf.tensor_scatter_nd_update(img_mask, indices, updates)
+ count += 1
+
+ img_mask = tf.expand_dims(img_mask, -1)
+ img_mask = tf.expand_dims(img_mask, 0)
+
+ mask_windows = window_partition(img_mask, window_size)
+ mask_windows = tf.reshape(mask_windows, (-1, window_size * window_size))
+ attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2)
+ attn_mask = tf.where(attn_mask != 0, float(-100.0), attn_mask)
+ attn_mask = tf.where(attn_mask == 0, float(0.0), attn_mask)
+ return attn_mask
+
+ def maybe_pad(
+ self, hidden_states: tf.Tensor, window_size: int, height: int, width: int
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
+ pad_right = (window_size - width % window_size) % window_size
+ pad_bottom = (window_size - height % window_size) % window_size
+ pad_values = [[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]]
+ hidden_states = tf.pad(hidden_states, pad_values)
+ pad_values = tf.reshape(pad_values, (-1,))
+ return hidden_states, pad_values
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: bool = False,
+ training: bool = False,
+ ) -> tf.Tensor:
+ # if window size is larger than input resolution, we don't partition windows
+ min_res = tf.reduce_min(input_dimensions)
+ shift_size = 0 if min_res <= self.window_size else self.shift_size
+ window_size = min_res if min_res <= self.window_size else self.window_size
+
+ height, width = input_dimensions
+ batch_size, _, channels = shape_list(hidden_states)
+ shortcut = hidden_states
+
+ hidden_states = self.layernorm_before(hidden_states, training=training)
+ hidden_states = tf.reshape(hidden_states, (batch_size, height, width, channels))
+ # pad hidden_states to multiples of window size
+ hidden_states, pad_values = self.maybe_pad(hidden_states, window_size, height, width)
+
+ _, height_pad, width_pad, _ = shape_list(hidden_states)
+ # cyclic shift
+ if shift_size > 0:
+ shifted_hidden_states = tf.roll(hidden_states, shift=(-shift_size, -shift_size), axis=(1, 2))
+ else:
+ shifted_hidden_states = hidden_states
+
+ # partition windows
+ hidden_states_windows = window_partition(shifted_hidden_states, window_size)
+ hidden_states_windows = tf.reshape(hidden_states_windows, (-1, window_size * window_size, channels))
+ attn_mask = self.get_attn_mask(
+ height=height_pad, width=width_pad, window_size=window_size, shift_size=shift_size
+ )
+
+ attention_outputs = self.attention(
+ hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions, training=training
+ )
+
+ attention_output = attention_outputs[0]
+
+ attention_windows = tf.reshape(attention_output, (-1, window_size, window_size, channels))
+ shifted_windows = window_reverse(attention_windows, window_size, height_pad, width_pad)
+
+ # reverse cyclic shift
+ if shift_size > 0:
+ attention_windows = tf.roll(shifted_windows, shift=(shift_size, shift_size), axis=(1, 2))
+ else:
+ attention_windows = shifted_windows
+
+ was_padded = pad_values[3] > 0 or pad_values[5] > 0
+ if was_padded:
+ attention_windows = attention_windows[:, :height, :width, :]
+
+ attention_windows = tf.reshape(attention_windows, (batch_size, height * width, channels))
+
+ hidden_states = shortcut + self.drop_path(attention_windows, training=training)
+
+ layer_output = self.layernorm_after(hidden_states, training=training)
+ layer_output = self.intermediate(layer_output)
+ layer_output = hidden_states + self.swin_output(layer_output, training=training)
+
+ layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
+ return layer_outputs
+
+
+class TFSwinStage(tf.keras.layers.Layer):
+ def __init__(
+ self,
+ config: SwinConfig,
+ dim: int,
+ input_resolution: Tuple[int, int],
+ depth: int,
+ num_heads: int,
+ drop_path: List[float],
+ downsample: Optional[Callable],
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.config = config
+ self.dim = dim
+ self.blocks = [
+ TFSwinLayer(
+ config=config,
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ shift_size=0 if (i % 2 == 0) else config.window_size // 2,
+ name=f"blocks.{i}",
+ )
+ for i in range(depth)
+ ]
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(
+ input_resolution,
+ dim=dim,
+ norm_layer=partial(tf.keras.layers.LayerNormalization, epsilon=1e-5),
+ name="downsample",
+ )
+ else:
+ self.downsample = None
+
+ self.pointing = False
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ training: bool = False,
+ ) -> Tuple[tf.Tensor, ...]:
+ height, width = input_dimensions
+ for i, layer_module in enumerate(self.blocks):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, training=training
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if self.downsample is not None:
+ height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
+ output_dimensions = (height, width, height_downsampled, width_downsampled)
+ hidden_states = self.downsample(layer_outputs[0], input_dimensions, training=training)
+ else:
+ output_dimensions = (height, width, height, width)
+
+ stage_outputs = (hidden_states, output_dimensions)
+
+ if output_attentions:
+ stage_outputs += layer_outputs[1:]
+ return stage_outputs
+
+
+class TFSwinEncoder(tf.keras.layers.Layer):
+ def __init__(self, config: SwinConfig, grid_size: Tuple[int, int], **kwargs):
+ super().__init__(**kwargs)
+ self.num_layers = len(config.depths)
+ self.config = config
+ dpr = list((tf.linspace(0, 1, sum(config.depths)) * config.drop_path_rate).numpy())
+ self.layers = [
+ TFSwinStage(
+ config=config,
+ dim=int(config.embed_dim * 2**i_layer),
+ input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
+ depth=config.depths[i_layer],
+ num_heads=config.num_heads[i_layer],
+ drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
+ downsample=TFSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
+ name=f"layers.{i_layer}",
+ )
+ for i_layer in range(self.num_layers)
+ ]
+
+ self.gradient_checkpointing = False
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ training: bool = False,
+ ) -> Union[Tuple[tf.Tensor, ...], TFSwinEncoderOutput]:
+ all_input_dimensions = ()
+ all_hidden_states = () if output_hidden_states else None
+ all_reshaped_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if output_hidden_states:
+ batch_size, _, hidden_size = shape_list(hidden_states)
+ # rearrange b (h w) c -> b c h w
+ reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size))
+ reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2))
+ all_hidden_states += (hidden_states,)
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+ for i, layer_module in enumerate(self.layers):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, training=training
+ )
+
+ hidden_states = layer_outputs[0]
+ output_dimensions = layer_outputs[1]
+
+ input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+ all_input_dimensions += (input_dimensions,)
+
+ if output_hidden_states:
+ batch_size, _, hidden_size = shape_list(hidden_states)
+ # rearrange b (h w) c -> b c h w
+ reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size))
+ reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2))
+ all_hidden_states += (hidden_states,)
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+ if output_attentions:
+ all_self_attentions += layer_outputs[2:]
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+ return TFSwinEncoderOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ reshaped_hidden_states=all_reshaped_hidden_states,
+ )
+
+
+class TFSwinPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = SwinConfig
+ base_model_prefix = "swin"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def _set_gradient_checkpointing(self, module, value=False) -> None:
+ if isinstance(module, TFSwinEncoder):
+ module.gradient_checkpointing = value
+
+ @property
+ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
+ """
+ Dummy inputs to build the network.
+
+ Returns:
+ `Dict[str, tf.Tensor]`: The dummy inputs.
+ """
+ VISION_DUMMY_INPUTS = tf.random.uniform(
+ shape=(3, self.config.num_channels, self.config.image_size, self.config.image_size),
+ dtype=tf.float32,
+ )
+ return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
+
+ @tf.function(
+ input_signature=[
+ {
+ "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
+ }
+ ]
+ )
+ def serving(self, inputs):
+ output = self.call(inputs)
+ return self.serving_output(output)
+
+
+SWIN_START_DOCSTRING = r"""
+ This model is a TensorFlow
+ [tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) subclass. Use it as a
+ regular TensorFlow layer and refer to the TensorFlow documentation for all matters related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`SwinConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SWIN_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
+ [`AutoFeatureExtractor.__call__`] for details.
+ head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+def normalize_data_format(value: str) -> str:
+ """
+ From tensorflow addons
+ https://github.com/tensorflow/addons/blob/8cec33fcaaf1cf90aec7bdd55a0fcdbb251ce5c2/tensorflow_addons/utils/keras_utils.py#L71
+ """
+ if value is None:
+ value = tf.keras.backend.image_data_format()
+ data_format = value.lower()
+ if data_format not in {"channels_first", "channels_last"}:
+ raise ValueError(
+ 'The `data_format` argument must be one of "channels_first", "channels_last". Received: ' + str(value)
+ )
+ return data_format
+
+
+class AdaptiveAveragePooling1D(tf.keras.layers.Layer):
+ """
+ Average 1D Pooling with adaptive kernel size.
+
+ Args:
+ output_size: An integer or tuple/list of a single integer, specifying pooled_features.
+ The new size of output channels.
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape `(batch, steps, channels)` while `channels_first` corresponds
+ to inputs with shape `(batch, channels, steps)`.
+ Input shape:
+ - If `data_format='channels_last'`: 3D tensor with shape `(batch, steps, channels)`.
+ - If `data_format='channels_first'`: 3D tensor with shape `(batch, channels, steps)`.
+ Output shape:
+ - If `data_format='channels_last'`: 3D tensor with shape `(batch_size, pooled_steps, channels)`.
+ - If `data_format='channels_first'`: 3D tensor with shape `(batch_size, channels, pooled_steps)`.
+
+ Adapted from [tensorflow-addon's adaptive pooling.py](
+ https://github.com/tensorflow/addons/blob/8cec33fcaaf1cf90aec7bdd55a0fcdbb251ce5c2/tensorflow_addons/layers/adaptive_pooling.py#L90-L120
+ )
+ """
+
+ def __init__(
+ self,
+ output_size: Union[int, Iterable[int]],
+ reduce_function: Callable = tf.reduce_mean,
+ data_format: Optional[str] = None,
+ **kwargs,
+ ) -> None:
+ self.data_format = normalize_data_format(data_format)
+ self.reduce_function = reduce_function
+ self.output_size = (output_size,) if isinstance(output_size, int) else tuple(output_size)
+ super().__init__(**kwargs)
+
+ def call(self, inputs: tf.Tensor, *args) -> tf.Tensor:
+ bins = self.output_size[0]
+ if self.data_format == "channels_last":
+ splits = tf.split(inputs, bins, axis=1)
+ splits = tf.stack(splits, axis=1)
+ out_vect = self.reduce_function(splits, axis=2)
+ else:
+ splits = tf.split(inputs, bins, axis=2)
+ splits = tf.stack(splits, axis=2)
+ out_vect = self.reduce_function(splits, axis=3)
+ return out_vect
+
+ def compute_output_shape(self, input_shape: Iterable[int]) -> tf.TensorShape:
+ input_shape = tf.TensorShape(input_shape).as_list()
+ if self.data_format == "channels_last":
+ shape = tf.TensorShape([input_shape[0], self.output_size[0], input_shape[2]])
+ else:
+ shape = tf.TensorShape([input_shape[0], input_shape[1], self.output_size[0]])
+ return shape
+
+ def get_config(self) -> Dict[str, Any]:
+ config = {
+ "output_size": self.output_size,
+ "data_format": self.data_format,
+ }
+ base_config = super().get_config()
+ return {**base_config, **config}
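+
+ # Illustrative usage (a sketch, not part of the model graph): with `output_size=1` and the
+ # default `channels_last` data format, the steps axis is collapsed into a single bin and
+ # mean-reduced, e.g.
+ #   pool = AdaptiveAveragePooling1D(output_size=1)
+ #   pool(tf.ones((2, 49, 768))).shape  # -> (2, 1, 768)
+ # TFSwinMainLayer below relies on this to produce its pooled output.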
+
+
+@keras_serializable
+class TFSwinMainLayer(tf.keras.layers.Layer):
+ config_class = SwinConfig
+
+ def __init__(
+ self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.config = config
+ self.num_layers = len(config.depths)
+ self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
+
+ self.embeddings = TFSwinEmbeddings(config, use_mask_token=use_mask_token, name="embeddings")
+ self.encoder = TFSwinEncoder(config, self.embeddings.patch_grid, name="encoder")
+
+ self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+ self.pooler = AdaptiveAveragePooling1D(output_size=(1,)) if add_pooling_layer else None
+
+ def get_input_embeddings(self) -> TFSwinPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List]):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
+ base class `PreTrainedModel`.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def get_head_mask(self, head_mask: Optional[Any]) -> List:
+ if head_mask is not None:
+ raise NotImplementedError
+ return [None] * len(self.config.depths)
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ bool_masked_pos: Optional[tf.Tensor] = None,
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[TFSwinModelOutput, Tuple[tf.Tensor, ...]]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask)
+ embedding_output, input_dimensions = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos, training=training
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ input_dimensions,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output, training=training)
+
+ pooled_output = None
+ if self.pooler is not None:
+ batch_size, _, num_features = shape_list(sequence_output)
+ pooled_output = self.pooler(sequence_output)
+ pooled_output = tf.reshape(pooled_output, (batch_size, num_features))
+
+ if not return_dict:
+ output = (sequence_output, pooled_output) + encoder_outputs[1:]
+ return output
+
+ return TFSwinModelOutput(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
+ )
+
+
+@add_start_docstrings(
+ "The bare Swin Model transformer outputting raw hidden-states without any specific head on top.",
+ SWIN_START_DOCSTRING,
+)
+class TFSwinModel(TFSwinPreTrainedModel):
+ def __init__(
+ self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
+ ) -> None:
+ super().__init__(config, **kwargs)
+ self.config = config
+ self.swin = TFSwinMainLayer(config, name="swin")
+
+ @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFSwinModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ bool_masked_pos: Optional[tf.Tensor] = None,
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[TFSwinModelOutput, Tuple[tf.Tensor, ...]]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ swin_outputs = self.swin(
+ pixel_values=pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return swin_outputs
+
+ def serving_output(self, output: TFSwinModelOutput) -> TFSwinModelOutput:
+ # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
+ return TFSwinModelOutput(
+ last_hidden_state=output.last_hidden_state,
+ pooler_output=output.pooler_output,
+ hidden_states=output.hidden_states,
+ attentions=output.attentions,
+ reshaped_hidden_states=output.reshaped_hidden_states,
+ )
+
+
+class TFSwinPixelShuffle(tf.keras.layers.Layer):
+ """TF layer implementation of torch.nn.PixelShuffle"""
+
+ def __init__(self, upscale_factor: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ if not isinstance(upscale_factor, int) or upscale_factor < 2:
+ raise ValueError(f"upscale_factor must be an integer value >= 2 got {upscale_factor}")
+ self.upscale_factor = upscale_factor
+
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ hidden_states = x
+ batch_size, _, _, num_input_channels = shape_list(hidden_states)
+ block_size_squared = self.upscale_factor**2
+ output_depth = int(num_input_channels / block_size_squared)
+ # When the number of output channels >= 2, PyTorch's PixelShuffle and
+ # TF's depth_to_space differ in their output as the order of channels selected for combining
+ # is a permutation of the other c.f.
+ # https://stackoverflow.com/questions/68272502/tf-depth-to-space-not-same-as-torchs-pixelshuffle-when-output-channels-1
+ permutation = tf.constant(
+ [[i + j * block_size_squared for i in range(block_size_squared) for j in range(output_depth)]]
+ )
+ hidden_states = tf.gather(params=hidden_states, indices=tf.tile(permutation, [batch_size, 1]), batch_dims=-1)
+ hidden_states = tf.nn.depth_to_space(hidden_states, block_size=self.upscale_factor, data_format="NHWC")
+ return hidden_states
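+
+ # Illustrative shape mapping (assuming NHWC input whose channel count is divisible by
+ # upscale_factor**2): TFSwinPixelShuffle(2) maps (1, 7, 7, 32) -> (1, 14, 14, 8); the gather
+ # above reorders channels before depth_to_space so the output follows the channel ordering of
+ # torch.nn.PixelShuffle applied to the corresponding NCHW tensor.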
+
+
+class TFSwinDecoder(tf.keras.layers.Layer):
+ def __init__(self, config: SwinConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.conv2d = tf.keras.layers.Conv2D(
+ filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, strides=1, name="0"
+ )
+ self.pixel_shuffle = TFSwinPixelShuffle(config.encoder_stride, name="1")
+
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ hidden_states = x
+ # B,C,H,W -> B,H,W,C
+ hidden_states = tf.transpose(hidden_states, (0, 2, 3, 1))
+ hidden_states = self.conv2d(hidden_states)
+ hidden_states = self.pixel_shuffle(hidden_states)
+ # B,H,W,C -> B,C,H,W
+ hidden_states = tf.transpose(hidden_states, (0, 3, 1, 2))
+ return hidden_states
+
+
+@add_start_docstrings(
+ "Swin Model with a decoder on top for masked image modeling, as proposed in"
+ " [SimMIM](https://arxiv.org/abs/2111.09886).",
+ SWIN_START_DOCSTRING,
+)
+class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
+ def __init__(self, config: SwinConfig):
+ super().__init__(config)
+
+ self.swin = TFSwinMainLayer(config, add_pooling_layer=False, use_mask_token=True, name="swin")
+
+ self.decoder = TFSwinDecoder(config, name="decoder")
+
+ @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFSwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ bool_masked_pos: Optional[tf.Tensor] = None,
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[Tuple, TFSwinMaskedImageModelingOutput]:
+ r"""
+ bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+ Returns:
+
+ Examples:
+ ```python
+ >>> from transformers import AutoFeatureExtractor, TFSwinForMaskedImageModeling
+ >>> import tensorflow as tf
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
+ >>> model = TFSwinForMaskedImageModeling.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
+
+ >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
+ >>> pixel_values = feature_extractor(images=image, return_tensors="tf").pixel_values
+ >>> # create random boolean mask of shape (batch_size, num_patches)
+ >>> bool_masked_pos = tf.random.uniform((1, num_patches)) >= 0.5
+
+ >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
+ >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
+ >>> list(reconstructed_pixel_values.shape)
+ [1, 3, 224, 224]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.swin(
+ pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = outputs[0]
+ # Reshape to (batch_size, num_channels, height, width)
+ sequence_output = tf.transpose(sequence_output, (0, 2, 1))
+ batch_size, num_channels, sequence_length = shape_list(sequence_output)
+ height = width = int(sequence_length**0.5)
+ sequence_output = tf.reshape(sequence_output, (batch_size, num_channels, height, width))
+
+ # Reconstruct pixel values
+ reconstructed_pixel_values = self.decoder(sequence_output)
+
+ masked_im_loss = None
+ if bool_masked_pos is not None:
+ size = self.config.image_size // self.config.patch_size
+ bool_masked_pos = tf.reshape(bool_masked_pos, (-1, size, size))
+ mask = tf.repeat(bool_masked_pos, self.config.patch_size, 1)
+ mask = tf.repeat(mask, self.config.patch_size, 2)
+ mask = tf.expand_dims(mask, 1)
+ mask = tf.cast(mask, tf.float32)
+
+ reconstruction_loss = tf.keras.losses.mean_absolute_error(
+ # Swap axes as metric calculation reduces over the final dimension
+ tf.transpose(pixel_values, (1, 2, 3, 0)),
+ tf.transpose(reconstructed_pixel_values, (1, 2, 3, 0)),
+ )
+ reconstruction_loss = tf.expand_dims(reconstruction_loss, 0)
+ total_loss = tf.reduce_sum(reconstruction_loss * mask)
+ num_masked_pixels = (tf.reduce_sum(mask) + 1e-5) * self.config.num_channels
+ masked_im_loss = total_loss / num_masked_pixels
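+ # i.e. an L1 reconstruction loss accumulated over the masked positions only and normalized
+ # by (number of masked pixels + epsilon) * num_channels, following the SimMIM objective.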
+
+ if not return_dict:
+ output = (reconstructed_pixel_values,) + outputs[2:]
+ return ((masked_im_loss,) + output) if masked_im_loss is not None else output
+
+ return TFSwinMaskedImageModelingOutput(
+ loss=masked_im_loss,
+ logits=reconstructed_pixel_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ reshaped_hidden_states=outputs.reshaped_hidden_states,
+ )
+
+ def serving_output(self, output: TFSwinMaskedImageModelingOutput) -> TFSwinMaskedImageModelingOutput:
+ # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
+ return TFSwinMaskedImageModelingOutput(
+ logits=output.logits,
+ hidden_states=output.hidden_states,
+ attentions=output.attentions,
+ reshaped_hidden_states=output.reshaped_hidden_states,
+ )
+
+
+@add_start_docstrings(
+ """
+ Swin Model transformer with an image classification head on top (a linear layer on top of the average-pooled final
+ hidden state), e.g. for ImageNet.
+ """,
+ SWIN_START_DOCSTRING,
+)
+class TFSwinForImageClassification(TFSwinPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: SwinConfig):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.swin = TFSwinMainLayer(config, name="swin")
+
+ # Classifier head
+ self.classifier = (
+ tf.keras.layers.Dense(config.num_labels, name="classifier")
+ if config.num_labels > 0
+ else tf.keras.layers.Activation("linear", name="classifier")
+ )
+
+ @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=TFSwinImageClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ head_mask: Optional[tf.Tensor] = None,
+ labels: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[Tuple[tf.Tensor, ...], TFSwinImageClassifierOutput]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.swin(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ pooled_output = outputs[1]
+
+ logits = self.classifier(pooled_output, training=training)
+
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSwinImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ reshaped_hidden_states=outputs.reshaped_hidden_states,
+ )
+
+ def serving_output(self, output: TFSwinImageClassifierOutput) -> TFSwinImageClassifierOutput:
+ # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
+ return TFSwinImageClassifierOutput(
+ logits=output.logits,
+ hidden_states=output.hidden_states,
+ attentions=output.attentions,
+ reshaped_hidden_states=output.reshaped_hidden_states,
+ )
diff --git a/src/transformers/models/swinv2/__init__.py b/src/transformers/models/swinv2/__init__.py
new file mode 100644
index 000000000000..1cf259b8303e
--- /dev/null
+++ b/src/transformers/models/swinv2/__init__.py
@@ -0,0 +1,65 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+# rely on isort to merge the imports
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {
+ "configuration_swinv2": ["SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Swinv2Config"],
+}
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_swinv2"] = [
+ "SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "Swinv2ForImageClassification",
+ "Swinv2ForMaskedImageModeling",
+ "Swinv2Model",
+ "Swinv2PreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_swinv2 import SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP, Swinv2Config
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_swinv2 import (
+ SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST,
+ Swinv2ForImageClassification,
+ Swinv2ForMaskedImageModeling,
+ Swinv2Model,
+ Swinv2PreTrainedModel,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/swinv2/configuration_swinv2.py b/src/transformers/models/swinv2/configuration_swinv2.py
new file mode 100644
index 000000000000..f861be05fe1f
--- /dev/null
+++ b/src/transformers/models/swinv2/configuration_swinv2.py
@@ -0,0 +1,147 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Swinv2 Transformer model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "microsoft/swinv2_tiny_patch4_windows8_256": (
+ "https://huggingface.co/microsoft/swinv2_tiny_patch4_windows8_256/resolve/main/config.json"
+ ),
+}
+
+
+class Swinv2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Swinv2Model`]. It is used to instantiate a Swin
+ Transformer v2 model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the Swin Transformer v2
+ [microsoft/swinv2_tiny_patch4_windows8_256](https://huggingface.co/microsoft/swinv2_tiny_patch4_windows8_256)
+ architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 4):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ embed_dim (`int`, *optional*, defaults to 96):
+ Dimensionality of patch embedding.
+ depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`):
+ Depth of each layer in the Transformer encoder.
+ num_heads (`list(int)`, *optional*, defaults to `[3, 6, 12, 24]`):
+ Number of attention heads in each layer of the Transformer encoder.
+ window_size (`int`, *optional*, defaults to 7):
+ Size of windows.
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
+ Ratio of MLP hidden dimensionality to embedding dimensionality.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not a learnable bias should be added to the queries, keys and values.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings and encoder.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
+ Stochastic depth rate.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
+ `"selu"` and `"gelu_new"` are supported.
+ use_absolute_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to add absolute position embeddings to the patch embeddings.
+ patch_norm (`bool`, *optional*, defaults to `True`):
+ Whether or not to add layer normalization after patch embedding.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ encoder_stride (`int`, *optional*, defaults to 32):
+ Factor to increase the spatial resolution by in the decoder head for masked image modeling.
+
+ Example:
+
+ ```python
+ >>> from transformers import Swinv2Config, Swinv2Model
+
+ >>> # Initializing a Swinv2 microsoft/swinv2_tiny_patch4_windows8_256 style configuration
+ >>> configuration = Swinv2Config()
+
+ >>> # Initializing a model from the microsoft/swinv2_tiny_patch4_windows8_256 style configuration
+ >>> model = Swinv2Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "swinv2"
+
+ attribute_map = {
+ "num_attention_heads": "num_heads",
+ "num_hidden_layers": "num_layers",
+ }
+
+ def __init__(
+ self,
+ image_size=224,
+ patch_size=4,
+ num_channels=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ drop_path_rate=0.1,
+ hidden_act="gelu",
+ use_absolute_embeddings=False,
+ patch_norm=True,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ encoder_stride=32,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.embed_dim = embed_dim
+ self.depths = depths
+ self.num_layers = len(depths)
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.mlp_ratio = mlp_ratio
+ self.qkv_bias = qkv_bias
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.drop_path_rate = drop_path_rate
+ self.hidden_act = hidden_act
+ self.use_absolute_embeddings = use_absolute_embeddings
+ self.patch_norm = patch_norm
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ self.encoder_stride = encoder_stride
+ # we set the hidden_size attribute in order to make Swinv2 work with VisionEncoderDecoderModel
+ # this indicates the channel dimension after the last stage of the model
+ self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
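+ # e.g. with the defaults embed_dim=96 and depths=[2, 2, 6, 2], hidden_size = 96 * 2**3 = 768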
+ self.pretrained_window_sizes = (0, 0, 0, 0)
diff --git a/src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py b/src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py
new file mode 100644
index 000000000000..148793e3043b
--- /dev/null
+++ b/src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py
@@ -0,0 +1,219 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert Swinv2 checkpoints from the timm library."""
+
+import argparse
+import json
+from pathlib import Path
+
+import torch
+from PIL import Image
+
+import requests
+import timm
+from huggingface_hub import hf_hub_download
+from transformers import AutoFeatureExtractor, Swinv2Config, Swinv2ForImageClassification
+
+
+def get_swinv2_config(swinv2_name):
+ config = Swinv2Config()
+ name_split = swinv2_name.split("_")
+
+ model_size = name_split[1]
+ if "to" in name_split[3]:
+ img_size = int(name_split[3][-3:])
+ else:
+ img_size = int(name_split[3])
+ if "to" in name_split[2]:
+ window_size = int(name_split[2][-2:])
+ else:
+ window_size = int(name_split[2][6:])
+
+ if model_size == "tiny":
+ embed_dim = 96
+ depths = (2, 2, 6, 2)
+ num_heads = (3, 6, 12, 24)
+ elif model_size == "small":
+ embed_dim = 96
+ depths = (2, 2, 18, 2)
+ num_heads = (3, 6, 12, 24)
+ elif model_size == "base":
+ embed_dim = 128
+ depths = (2, 2, 18, 2)
+ num_heads = (4, 8, 16, 32)
+ else:
+ embed_dim = 192
+ depths = (2, 2, 18, 2)
+ num_heads = (6, 12, 24, 48)
+
+ if "to" in swinv2_name:
+ config.pretrained_window_sizes = (12, 12, 12, 6)
+
+ if ("22k" in swinv2_name) and ("to" not in swinv2_name):
+ num_classes = 21841
+ repo_id = "datasets/huggingface/label-files"
+ filename = "imagenet-22k-id2label.json"
+ id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+ id2label = {int(k): v for k, v in id2label.items()}
+ config.id2label = id2label
+ config.label2id = {v: k for k, v in id2label.items()}
+
+ else:
+ num_classes = 1000
+ repo_id = "datasets/huggingface/label-files"
+ filename = "imagenet-1k-id2label.json"
+ id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+ id2label = {int(k): v for k, v in id2label.items()}
+ config.id2label = id2label
+ config.label2id = {v: k for k, v in id2label.items()}
+
+ config.image_size = img_size
+ config.num_labels = num_classes
+ config.embed_dim = embed_dim
+ config.depths = depths
+ config.num_heads = num_heads
+ config.window_size = window_size
+
+ return config
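+
+ # Illustrative parse (assuming a timm-style name): for "swinv2_base_window12to16_192to256_22kft1k",
+ # name_split[2] = "window12to16" yields window_size=16, name_split[3] = "192to256" yields img_size=256,
+ # and the presence of "to" selects pretrained_window_sizes=(12, 12, 12, 6).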
+
+
+def rename_key(name):
+ if "patch_embed.proj" in name:
+ name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
+ if "patch_embed.norm" in name:
+ name = name.replace("patch_embed.norm", "embeddings.norm")
+ if "layers" in name:
+ name = "encoder." + name
+ if "attn.proj" in name:
+ name = name.replace("attn.proj", "attention.output.dense")
+ if "attn" in name:
+ name = name.replace("attn", "attention.self")
+ if "norm1" in name:
+ name = name.replace("norm1", "layernorm_before")
+ if "norm2" in name:
+ name = name.replace("norm2", "layernorm_after")
+ if "mlp.fc1" in name:
+ name = name.replace("mlp.fc1", "intermediate.dense")
+ if "mlp.fc2" in name:
+ name = name.replace("mlp.fc2", "output.dense")
+ if "q_bias" in name:
+ name = name.replace("q_bias", "query.bias")
+ if "k_bias" in name:
+ name = name.replace("k_bias", "key.bias")
+ if "v_bias" in name:
+ name = name.replace("v_bias", "value.bias")
+ if "cpb_mlp" in name:
+ name = name.replace("cpb_mlp", "continuous_position_bias_mlp")
+ if name == "norm.weight":
+ name = "layernorm.weight"
+ if name == "norm.bias":
+ name = "layernorm.bias"
+
+ if "head" in name:
+ name = name.replace("head", "classifier")
+ else:
+ name = "swinv2." + name
+
+ return name
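+
+ # Illustrative mapping: "layers.0.blocks.0.attn.proj.weight"
+ # -> "swinv2.encoder.layers.0.blocks.0.attention.output.dense.weight"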
+
+
+def convert_state_dict(orig_state_dict, model):
+ for key in orig_state_dict.copy().keys():
+ val = orig_state_dict.pop(key)
+
+ if "mask" in key:
+ continue
+ elif "qkv" in key:
+ key_split = key.split(".")
+ layer_num = int(key_split[1])
+ block_num = int(key_split[3])
+ dim = model.swinv2.encoder.layers[layer_num].blocks[block_num].attention.self.all_head_size
+
+ if "weight" in key:
+ orig_state_dict[
+ f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.weight"
+ ] = val[:dim, :]
+ orig_state_dict[
+ f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.weight"
+ ] = val[dim : dim * 2, :]
+ orig_state_dict[
+ f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.weight"
+ ] = val[-dim:, :]
+ else:
+ orig_state_dict[
+ f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.bias"
+ ] = val[:dim]
+ orig_state_dict[f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.bias"] = val[
+ dim : dim * 2
+ ]
+ orig_state_dict[
+ f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.bias"
+ ] = val[-dim:]
+ else:
+ orig_state_dict[rename_key(key)] = val
+
+ return orig_state_dict
+
+
+def convert_swinv2_checkpoint(swinv2_name, pytorch_dump_folder_path):
+ timm_model = timm.create_model(swinv2_name, pretrained=True)
+ timm_model.eval()
+
+ config = get_swinv2_config(swinv2_name)
+ model = Swinv2ForImageClassification(config)
+ model.eval()
+
+ new_state_dict = convert_state_dict(timm_model.state_dict(), model)
+ model.load_state_dict(new_state_dict)
+
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+
+ feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/{}".format(swinv2_name.replace("_", "-")))
+ image = Image.open(requests.get(url, stream=True).raw)
+ inputs = feature_extractor(images=image, return_tensors="pt")
+
+ timm_outs = timm_model(inputs["pixel_values"])
+ hf_outs = model(**inputs).logits
+
+ assert torch.allclose(timm_outs, hf_outs, atol=1e-3)
+
+ print(f"Saving model {swinv2_name} to {pytorch_dump_folder_path}")
+ model.save_pretrained(pytorch_dump_folder_path)
+
+ print(f"Saving feature extractor to {pytorch_dump_folder_path}")
+ feature_extractor.save_pretrained(pytorch_dump_folder_path)
+
+ model.push_to_hub(
+ repo_path_or_name=Path(pytorch_dump_folder_path, swinv2_name),
+ organization="nandwalritik",
+ commit_message="Add model",
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--swinv2_name",
+ default="swinv2_tiny_patch4_window8_256",
+ type=str,
+ help="Name of the Swinv2 timm model you'd like to convert.",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
+ )
+
+ args = parser.parse_args()
+ convert_swinv2_checkpoint(args.swinv2_name, args.pytorch_dump_folder_path)
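+
+ # Example invocation (a sketch; the dump path is a placeholder):
+ #   python convert_swinv2_timm_to_pytorch.py \
+ #       --swinv2_name swinv2_tiny_patch4_window8_256 \
+ #       --pytorch_dump_folder_path ./swinv2-tiny-dump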
diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py
new file mode 100644
index 000000000000..52f836d5b91d
--- /dev/null
+++ b/src/transformers/models/swinv2/modeling_swinv2.py
@@ -0,0 +1,1292 @@
+# coding=utf-8
+# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Swinv2 Transformer model."""
+
+
+import collections.abc
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_swinv2 import Swinv2Config
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "Swinv2Config"
+_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "microsoft/swinv2-tiny-patch4-window8-256"
+_EXPECTED_OUTPUT_SHAPE = [1, 64, 768]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "microsoft/swinv2-tiny-patch4-window8-256"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
+
+
+SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "microsoft/swinv2-tiny-patch4-window8-256",
+ # See all Swinv2 models at https://huggingface.co/models?filter=swinv2
+]
+
+
+# drop_path, Swinv2PatchEmbeddings, Swinv2PatchMerging and Swinv2DropPath are from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer_v2.py.
+
+
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->Swinv2
+class Swinv2EncoderOutput(ModelOutput):
+ """
+ Swinv2 encoder's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->Swinv2
+class Swinv2ModelOutput(ModelOutput):
+ """
+ Swinv2 model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
+ Average pooling of the last layer hidden-state.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ pooler_output: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinMaskedImageModelingOutput with Swin->Swinv2
+class Swinv2MaskedImageModelingOutput(ModelOutput):
+ """
+ Swinv2 masked image model outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
+ Masked image modeling (MLM) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Reconstructed pixel values.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->Swinv2
+class Swinv2ImageClassifierOutput(ModelOutput):
+ """
+ Swinv2 outputs for image classification.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+# Copied from transformers.models.swin.modeling_swin.window_partition
+def window_partition(input_feature, window_size):
+ """
+ Partitions the given input into windows.
+ """
+ batch_size, height, width, num_channels = input_feature.shape
+ input_feature = input_feature.view(
+ batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
+ )
+ windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
+ return windows
+
+
+# Copied from transformers.models.swin.modeling_swin.window_reverse
+def window_reverse(windows, window_size, height, width):
+ """
+ Merges windows to produce higher resolution features.
+ """
+ batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
+ windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
+ windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
+ return windows
+
+
+# Copied from transformers.models.swin.modeling_swin.drop_path
+def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
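+ # e.g. with drop_prob=0.1, each sample in the batch is dropped (zeroed) with probability 0.1
+ # and the surviving samples are scaled by 1 / 0.9 so the expected value is unchanged.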
+ return output
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->Swinv2
+class Swinv2DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->Swinv2
+class Swinv2Embeddings(nn.Module):
+ """
+ Construct the patch and position embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config, use_mask_token=False):
+ super().__init__()
+
+ self.patch_embeddings = Swinv2PatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.patch_grid = self.patch_embeddings.grid_size
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
+
+ if config.use_absolute_embeddings:
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
+ else:
+ self.position_embeddings = None
+
+ self.norm = nn.LayerNorm(config.embed_dim)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(
+ self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
+ ) -> Tuple[torch.Tensor]:
+ embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+ embeddings = self.norm(embeddings)
+ batch_size, seq_len, _ = embeddings.size()
+
+ if bool_masked_pos is not None:
+ mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+ # replace the masked visual tokens by mask_tokens
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ if self.position_embeddings is not None:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings, output_dimensions
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->Swinv2
+class Swinv2PatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.embed_dim
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+ self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def maybe_pad(self, pixel_values, height, width):
+ if width % self.patch_size[1] != 0:
+ pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
+ pixel_values = nn.functional.pad(pixel_values, pad_values)
+ if height % self.patch_size[0] != 0:
+ pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
+ pixel_values = nn.functional.pad(pixel_values, pad_values)
+ return pixel_values
+
+ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
+ _, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ # pad the input to be divisible by self.patch_size, if needed
+ pixel_values = self.maybe_pad(pixel_values, height, width)
+ embeddings = self.projection(pixel_values)
+ _, _, height, width = embeddings.shape
+ output_dimensions = (height, width)
+ embeddings = embeddings.flatten(2).transpose(1, 2)
+
+ return embeddings, output_dimensions
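+
+# Shape walk-through (illustrative; assumes the swinv2-tiny defaults of image_size=256,
+# patch_size=4, embed_dim=96 and num_channels=3):
+#
+# pixel_values: (batch_size, 3, 256, 256)
+# after self.projection: (batch_size, 96, 64, 64) -> output_dimensions == (64, 64)
+# after flatten(2).transpose(1, 2): (batch_size, 4096, 96)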
+
+
+class Swinv2PatchMerging(nn.Module):
+ """
+ Patch Merging Layer.
+
+ Args:
+ input_resolution (`Tuple[int]`):
+ Resolution of input feature.
+ dim (`int`):
+ Number of input channels.
+ norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
+ Normalization layer class.
+ """
+
+ def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(2 * dim)
+
+ def maybe_pad(self, input_feature, height, width):
+ should_pad = (height % 2 == 1) or (width % 2 == 1)
+ if should_pad:
+ pad_values = (0, 0, 0, width % 2, 0, height % 2)
+ input_feature = nn.functional.pad(input_feature, pad_values)
+
+ return input_feature
+
+ def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
+ height, width = input_dimensions
+ # `dim` is height * width
+ batch_size, dim, num_channels = input_feature.shape
+
+ input_feature = input_feature.view(batch_size, height, width, num_channels)
+ # pad input so that height and width are divisible by 2, if needed
+ input_feature = self.maybe_pad(input_feature, height, width)
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_0 = input_feature[:, 0::2, 0::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_1 = input_feature[:, 1::2, 0::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_2 = input_feature[:, 0::2, 1::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_3 = input_feature[:, 1::2, 1::2, :]
+ # [batch_size, height/2 * width/2, 4*num_channels]
+ input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
+ input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # [batch_size, height/2 * width/2, 4*C]
+
+ input_feature = self.reduction(input_feature)
+ input_feature = self.norm(input_feature)
+
+ return input_feature
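+
+# Illustrative shapes for one merging step (assuming height = width = 64 and
+# num_channels = 96): the four interleaved sub-grids are concatenated, then
+# linearly reduced, halving the resolution and doubling the channel dimension:
+#
+# input_feature: (batch_size, 64 * 64, 96)
+# after torch.cat: (batch_size, 32 * 32, 4 * 96)
+# after reduction + norm: (batch_size, 32 * 32, 2 * 96)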
+
+
+class Swinv2SelfAttention(nn.Module):
+ def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[0, 0]):
+ super().__init__()
+ if dim % num_heads != 0:
+ raise ValueError(
+ f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
+ )
+
+ self.num_attention_heads = num_heads
+ self.attention_head_size = int(dim / num_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.window_size = (
+ window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
+ )
+ self.pretrained_window_size = pretrained_window_size
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
+ # mlp to generate continuous relative position bias
+ self.continuous_position_bias_mlp = nn.Sequential(
+ nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
+ )
+
+ # get relative_coords_table
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
+ relative_coords_table = (
+ torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
+ .permute(1, 2, 0)
+ .contiguous()
+ .unsqueeze(0)
+ ) # [1, 2*window_height - 1, 2*window_width - 1, 2]
+ if pretrained_window_size[0] > 0:
+ relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
+ relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
+ else:
+ relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
+ relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = (
+ torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
+ )
+ self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
+ coords_flatten = torch.flatten(coords, 1)
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+ relative_coords[:, :, 0] += self.window_size[0] - 1
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1)
+ self.register_buffer("relative_position_index", relative_position_index, persistent=False)
+
+ self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=False)
+ self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ batch_size, dim, num_channels = hidden_states.shape
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # cosine attention
+ attention_scores = F.normalize(query_layer, dim=-1) @ F.normalize(key_layer, dim=-1).transpose(-2, -1)
+ logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
+ attention_scores = attention_scores * logit_scale
+ relative_position_bias_table = self.continuous_position_bias_mlp(self.relative_coords_table).view(
+ -1, self.num_attention_heads
+ )
+ # [window_height*window_width,window_height*window_width,num_attention_heads]
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+ )
+ # [num_attention_heads,window_height*window_width,window_height*window_width]
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
+ attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
+
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in Swinv2Model forward() function)
+ mask_shape = attention_mask.shape[0]
+ # add the mask once while the window dimension is unfolded, then flatten it back into the batch
+ attention_scores = attention_scores.view(
+ batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
+ ) + attention_mask.unsqueeze(1).unsqueeze(0)
+ attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
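+
+# Note (illustrative, not part of the model code): Swinv2 replaces the scaled
+# dot-product with cosine attention times a learned per-head temperature. The
+# clamp bound math.log(1.0 / 0.01) equals math.log(100.0), so after .exp() the
+# multiplier starts near 10 and can never exceed 100, e.g.:
+#
+# logit_scale = torch.clamp(torch.log(10 * torch.ones(1)), max=math.log(100.0)).exp()
+# # approximately tensor([10.]) at initialization, capped at 100.0 during training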
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->Swinv2
+class Swinv2SelfOutput(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(dim, dim)
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+class Swinv2Attention(nn.Module):
+ def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=0):
+ super().__init__()
+ self.self = Swinv2SelfAttention(
+ config=config,
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ pretrained_window_size=pretrained_window_size
+ if isinstance(pretrained_window_size, collections.abc.Iterable)
+ else (pretrained_window_size, pretrained_window_size),
+ )
+ self.output = Swinv2SelfOutput(config, dim)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->Swinv2
+class Swinv2Intermediate(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->Swinv2
+class Swinv2Output(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class Swinv2Layer(nn.Module):
+ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretrained_window_size=0):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.shift_size = shift_size
+ self.window_size = config.window_size
+ self.input_resolution = input_resolution
+ self.set_shift_and_window_size(input_resolution)
+ self.attention = Swinv2Attention(
+ config=config,
+ dim=dim,
+ num_heads=num_heads,
+ window_size=self.window_size,
+ pretrained_window_size=pretrained_window_size
+ if isinstance(pretrained_window_size, collections.abc.Iterable)
+ else (pretrained_window_size, pretrained_window_size),
+ )
+ self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+ self.drop_path = Swinv2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+ self.intermediate = Swinv2Intermediate(config, dim)
+ self.output = Swinv2Output(config, dim)
+ self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+
+ def set_shift_and_window_size(self, input_resolution):
+ target_window_size = (
+ self.window_size
+ if isinstance(self.window_size, collections.abc.Iterable)
+ else (self.window_size, self.window_size)
+ )
+ target_shift_size = (
+ self.shift_size
+ if isinstance(self.shift_size, collections.abc.Iterable)
+ else (self.shift_size, self.shift_size)
+ )
+ self.window_size = (
+ input_resolution[0] if input_resolution[0] <= target_window_size[0] else target_window_size[0]
+ )
+ self.shift_size = (
+ 0
+ if input_resolution
+ <= (
+ self.window_size
+ if isinstance(self.window_size, collections.abc.Iterable)
+ else (self.window_size, self.window_size)
+ )
+ else target_shift_size[0]
+ )
+
+ def get_attn_mask(self, height, width):
+ if self.shift_size > 0:
+ # calculate attention mask for shifted window multihead self attention
+ img_mask = torch.zeros((1, height, width, 1))
+ height_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ width_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ count = 0
+ for height_slice in height_slices:
+ for width_slice in width_slices:
+ img_mask[:, height_slice, width_slice, :] = count
+ count += 1
+
+ mask_windows = window_partition(img_mask, self.window_size)
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+ return attn_mask
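+
+ # Illustrative example (assuming window_size=8, shift_size=4, height=width=16):
+ # the two slice sets label a 3x3 grid of regions with values 0..8, and
+ # get_attn_mask(16, 16) has shape (num_windows=4, 64, 64). Only the top-left
+ # window is all zeros; windows that straddle regions contain -100.0 entries,
+ # which suppress cross-region attention after the softmax.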
+
+ def maybe_pad(self, hidden_states, height, width):
+ pad_right = (self.window_size - width % self.window_size) % self.window_size
+ pad_bottom = (self.window_size - height % self.window_size) % self.window_size
+ pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
+ hidden_states = nn.functional.pad(hidden_states, pad_values)
+ return hidden_states, pad_values
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ self.set_shift_and_window_size(input_dimensions)
+ height, width = input_dimensions
+ batch_size, _, channels = hidden_states.size()
+ shortcut = hidden_states
+
+ # pad hidden_states to multiples of window size
+ hidden_states = hidden_states.view(batch_size, height, width, channels)
+ hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
+ _, height_pad, width_pad, _ = hidden_states.shape
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_hidden_states = hidden_states
+
+ # partition windows
+ hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
+ hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
+ attn_mask = self.get_attn_mask(height_pad, width_pad)
+ if attn_mask is not None:
+ attn_mask = attn_mask.to(hidden_states_windows.device)
+
+ attention_outputs = self.attention(
+ hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
+ )
+
+ attention_output = attention_outputs[0]
+
+ attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
+ shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ attention_windows = shifted_windows
+
+ was_padded = pad_values[3] > 0 or pad_values[5] > 0
+ if was_padded:
+ attention_windows = attention_windows[:, :height, :width, :].contiguous()
+
+ attention_windows = attention_windows.view(batch_size, height * width, channels)
+ hidden_states = self.layernorm_before(attention_windows)
+ hidden_states = shortcut + self.drop_path(hidden_states)
+
+ layer_output = self.intermediate(hidden_states)
+ layer_output = self.output(layer_output)
+ layer_output = hidden_states + self.drop_path(self.layernorm_after(layer_output))
+
+ layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
+ return layer_outputs
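+
+# Note (illustrative): although the attribute names layernorm_before/layernorm_after
+# are kept from Swin, the forward above applies them as "res-post-norm" (a Swinv2
+# change): each sub-block output is normalized just before being added back to the
+# residual stream, i.e. roughly
+#
+# hidden_states = shortcut + drop_path(layernorm_before(attention(hidden_states)))
+# layer_output = hidden_states + drop_path(layernorm_after(mlp(hidden_states)))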
+
+
+class Swinv2Stage(nn.Module):
+ def __init__(
+ self, config, dim, input_resolution, depth, num_heads, drop_path, downsample, pretrained_window_size=0
+ ):
+ super().__init__()
+ self.config = config
+ self.dim = dim
+ self.blocks = nn.ModuleList(
+ [
+ Swinv2Layer(
+ config=config,
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ shift_size=0 if (i % 2 == 0) else config.window_size // 2,
+ pretrained_window_size=pretrained_window_size,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
+ else:
+ self.downsample = None
+
+ self.pointing = False
+
+ # Copied from transformers.models.swin.modeling_swin.SwinStage.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ height, width = input_dimensions
+ for i, layer_module in enumerate(self.blocks):
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if self.downsample is not None:
+ height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
+ output_dimensions = (height, width, height_downsampled, width_downsampled)
+ hidden_states = self.downsample(layer_outputs[0], input_dimensions)
+ else:
+ output_dimensions = (height, width, height, width)
+
+ stage_outputs = (hidden_states, output_dimensions)
+
+ if output_attentions:
+ stage_outputs += layer_outputs[1:]
+ return stage_outputs
+
+
+class Swinv2Encoder(nn.Module):
+ def __init__(self, config, grid_size, pretrained_window_sizes=(0, 0, 0, 0)):
+ super().__init__()
+ self.num_layers = len(config.depths)
+ self.config = config
+ if self.config.pretrained_window_sizes is not None:
+ pretrained_window_sizes = config.pretrained_window_sizes
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+ self.layers = nn.ModuleList(
+ [
+ Swinv2Stage(
+ config=config,
+ dim=int(config.embed_dim * 2**i_layer),
+ input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
+ depth=config.depths[i_layer],
+ num_heads=config.num_heads[i_layer],
+ drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
+ downsample=Swinv2PatchMerging if (i_layer < self.num_layers - 1) else None,
+ pretrained_window_size=pretrained_window_sizes[i_layer],
+ )
+ for i_layer in range(self.num_layers)
+ ]
+ )
+
+ self.gradient_checkpointing = False
+
+ # Copied from transformers.models.swin.modeling_swin.SwinEncoder.forward with SwinEncoderOutput->Swinv2EncoderOutput
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple, Swinv2EncoderOutput]:
+ all_input_dimensions = ()
+ all_hidden_states = () if output_hidden_states else None
+ all_reshaped_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if output_hidden_states:
+ batch_size, _, hidden_size = hidden_states.shape
+ # rearrange b (h w) c -> b c h w
+ reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+ reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+ all_hidden_states += (hidden_states,)
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+ for i, layer_module in enumerate(self.layers):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+ output_dimensions = layer_outputs[1]
+
+ input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+ all_input_dimensions += (input_dimensions,)
+
+ if output_hidden_states:
+ batch_size, _, hidden_size = hidden_states.shape
+ # rearrange b (h w) c -> b c h w
+ reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+ reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+ all_hidden_states += (hidden_states,)
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+ if output_attentions:
+ all_self_attentions += layer_outputs[2:]
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+ return Swinv2EncoderOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ reshaped_hidden_states=all_reshaped_hidden_states,
+ )
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->Swinv2,swin->swinv2
+class Swinv2PreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = Swinv2Config
+ base_model_prefix = "swinv2"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, Swinv2Encoder):
+ module.gradient_checkpointing = value
+
+
+SWINV2_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`Swinv2Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SWINV2_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
+ [`AutoFeatureExtractor.__call__`] for details.
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Swinv2 Model transformer outputting raw hidden-states without any specific head on top.",
+ SWINV2_START_DOCSTRING,
+)
+# Copied from transformers.models.swin.modeling_swin.SwinModel with SWIN->SWINV2,Swin->Swinv2
+class Swinv2Model(Swinv2PreTrainedModel):
+ def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
+ super().__init__(config)
+ self.config = config
+ self.num_layers = len(config.depths)
+ self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
+
+ self.embeddings = Swinv2Embeddings(config, use_mask_token=use_mask_token)
+ self.encoder = Swinv2Encoder(config, self.embeddings.patch_grid)
+
+ self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
+ self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=Swinv2ModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Swinv2ModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, len(self.config.depths))
+
+ embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ input_dimensions,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+
+ pooled_output = None
+ if self.pooler is not None:
+ pooled_output = self.pooler(sequence_output.transpose(1, 2))
+ pooled_output = torch.flatten(pooled_output, 1)
+
+ if not return_dict:
+ output = (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return output
+
+ return Swinv2ModelOutput(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
+ )
+
+
+@add_start_docstrings(
+ "Swinv2 Model with a decoder on top for masked image modeling, as proposed in"
+ " [SimMIM](https://arxiv.org/abs/2111.09886).",
+ SWINV2_START_DOCSTRING,
+)
+# Copied from transformers.models.swin.modeling_swin.SwinForMaskedImageModeling with SWIN->SWINV2,Swin->Swinv2,swin->swinv2,224->256,window7->window8
+class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.swinv2 = Swinv2Model(config, add_pooling_layer=False, use_mask_token=True)
+
+ num_features = int(config.embed_dim * 2 ** (config.num_layers - 1))
+ self.decoder = nn.Sequential(
+ nn.Conv2d(
+ in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
+ ),
+ nn.PixelShuffle(config.encoder_stride),
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Swinv2MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Swinv2MaskedImageModelingOutput]:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+ Returns:
+
+ Examples:
+ ```python
+ >>> from transformers import AutoFeatureExtractor, Swinv2ForMaskedImageModeling
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
+ >>> model = Swinv2ForMaskedImageModeling.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
+
+ >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
+ >>> pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
+ >>> # create random boolean mask of shape (batch_size, num_patches)
+ >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
+
+ >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
+ >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
+ >>> list(reconstructed_pixel_values.shape)
+ [1, 3, 256, 256]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.swinv2(
+ pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ # Reshape to (batch_size, num_channels, height, width)
+ sequence_output = sequence_output.transpose(1, 2)
+ batch_size, num_channels, sequence_length = sequence_output.shape
+ height = width = math.floor(sequence_length**0.5)
+ sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)
+
+ # Reconstruct pixel values
+ reconstructed_pixel_values = self.decoder(sequence_output)
+
+ masked_im_loss = None
+ if bool_masked_pos is not None:
+ size = self.config.image_size // self.config.patch_size
+ bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
+ mask = (
+ bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
+ .repeat_interleave(self.config.patch_size, 2)
+ .unsqueeze(1)
+ .contiguous()
+ )
+ reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
+ masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
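+ # Illustrative shapes (assuming image_size=256 and patch_size=4): bool_masked_pos
+ # (batch_size, 4096) -> (batch_size, 64, 64) -> mask (batch_size, 1, 256, 256), so the
+ # per-pixel L1 error is summed over masked pixels only and normalized by the number
+ # of masked pixels and by num_channels.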
+
+ if not return_dict:
+ output = (reconstructed_pixel_values,) + outputs[2:]
+ return ((masked_im_loss,) + output) if masked_im_loss is not None else output
+
+ return Swinv2MaskedImageModelingOutput(
+ loss=masked_im_loss,
+ logits=reconstructed_pixel_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ reshaped_hidden_states=outputs.reshaped_hidden_states,
+ )
+
+
+@add_start_docstrings(
+ """
+ Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
+ of the [CLS] token) e.g. for ImageNet.
+ """,
+ SWINV2_START_DOCSTRING,
+)
+# Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with SWIN->SWINV2,Swin->Swinv2,swin->swinv2
+class Swinv2ForImageClassification(Swinv2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.swinv2 = Swinv2Model(config)
+
+ # Classifier head
+ self.classifier = (
+ nn.Linear(self.swinv2.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=Swinv2ImageClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Swinv2ImageClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.swinv2(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return Swinv2ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ reshaped_hidden_states=outputs.reshaped_hidden_states,
+ )
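+
+
+# Usage sketch (illustrative; assumes the same "microsoft/swinv2-tiny-patch4-window8-256"
+# checkpoint used in the masked-image-modeling example above):
+#
+# feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
+# model = Swinv2ForImageClassification.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
+# inputs = feature_extractor(images=image, return_tensors="pt")
+# logits = model(**inputs).logits  # shape: (batch_size, config.num_labels)
+# predicted_class_idx = logits.argmax(-1).item()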
diff --git a/src/transformers/models/t5/__init__.py b/src/transformers/models/t5/__init__.py
index 9ccb94932843..2f0bd9521ac2 100644
--- a/src/transformers/models/t5/__init__.py
+++ b/src/transformers/models/t5/__init__.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
@@ -28,17 +29,30 @@
)
-_import_structure = {
- "configuration_t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config", "T5OnnxConfig"],
-}
+_import_structure = {"configuration_t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config", "T5OnnxConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_t5"] = ["T5Tokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_t5_fast"] = ["T5TokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_t5"] = [
"T5_PRETRAINED_MODEL_ARCHIVE_LIST",
"T5EncoderModel",
@@ -48,7 +62,12 @@
"load_tf_weights_in_t5",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_t5"] = [
"TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFT5EncoderModel",
@@ -57,8 +76,14 @@
"TFT5PreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_t5"] = [
+ "FlaxT5EncoderModel",
"FlaxT5ForConditionalGeneration",
"FlaxT5Model",
"FlaxT5PreTrainedModel",
@@ -68,13 +93,28 @@
if TYPE_CHECKING:
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config, T5OnnxConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_t5 import T5Tokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_t5_fast import T5TokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_t5 import (
T5_PRETRAINED_MODEL_ARCHIVE_LIST,
T5EncoderModel,
@@ -84,7 +124,12 @@
load_tf_weights_in_t5,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_t5 import (
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
TFT5EncoderModel,
@@ -93,8 +138,18 @@
TFT5PreTrainedModel,
)
- if is_flax_available():
- from .modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_t5 import (
+ FlaxT5EncoderModel,
+ FlaxT5ForConditionalGeneration,
+ FlaxT5Model,
+ FlaxT5PreTrainedModel,
+ )
else:
diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py
index b09539c86d70..a2bd03dfd74c 100644
--- a/src/transformers/models/t5/configuration_t5.py
+++ b/src/transformers/models/t5/configuration_t5.py
@@ -116,6 +116,22 @@ def __init__(
self.initializer_factor = initializer_factor
self.feed_forward_proj = feed_forward_proj
self.use_cache = use_cache
+
+ act_info = self.feed_forward_proj.split("-")
+ self.dense_act_fn = act_info[-1]
+ self.is_gated_act = act_info[0] == "gated"
+
+ if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
+ raise ValueError(
+ f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
+ "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+ "'gated-gelu' or 'relu'"
+ )
+
+ # for backwards compatibility
+ if feed_forward_proj == "gated-gelu":
+ self.dense_act_fn = "gelu_new"
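+
+ # Illustrative examples of the parsing above (not exhaustive):
+ # feed_forward_proj="relu"       -> is_gated_act=False, dense_act_fn="relu"
+ # feed_forward_proj="gated-gelu" -> is_gated_act=True,  dense_act_fn="gelu_new" (back-compat)
+ # feed_forward_proj="gated-silu" -> is_gated_act=True,  dense_act_fn="silu"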
+
super().__init__(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
diff --git a/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py
index a00203016822..7d9a20f3b0b3 100755
--- a/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py
@@ -49,8 +49,9 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained T5 model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py
index 767caea3eb38..06ad51054297 100644
--- a/src/transformers/models/t5/modeling_flax_t5.py
+++ b/src/transformers/models/t5/modeling_flax_t5.py
@@ -87,7 +87,7 @@ def __call__(self, hidden_states):
return self.weight * hidden_states
-class FlaxT5DenseReluDense(nn.Module):
+class FlaxT5DenseActDense(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32
@@ -108,16 +108,17 @@ def setup(self):
dtype=self.dtype,
)
self.dropout = nn.Dropout(self.config.dropout_rate)
+ self.act = ACT2FN[self.config.dense_act_fn]
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.wi(hidden_states)
- hidden_states = jax.nn.relu(hidden_states)
+ hidden_states = self.act(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.wo(hidden_states)
return hidden_states
-class FlaxT5DenseGatedGeluDense(nn.Module):
+class FlaxT5DenseGatedActDense(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@@ -144,10 +145,10 @@ def setup(self):
dtype=self.dtype,
)
self.dropout = nn.Dropout(self.config.dropout_rate)
- self.gelu_act = ACT2FN["gelu_new"]
+ self.act = ACT2FN[self.config.dense_act_fn]
def __call__(self, hidden_states, deterministic):
- hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
+ hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
@@ -160,14 +161,10 @@ class FlaxT5LayerFF(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
- if self.config.feed_forward_proj == "relu":
- self.DenseReluDense = FlaxT5DenseReluDense(self.config, dtype=self.dtype)
- elif self.config.feed_forward_proj == "gated-gelu":
- self.DenseReluDense = FlaxT5DenseGatedGeluDense(self.config, dtype=self.dtype)
+ if self.config.is_gated_act:
+ self.DenseReluDense = FlaxT5DenseGatedActDense(self.config, dtype=self.dtype)
else:
- raise ValueError(
- f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
- )
+ self.DenseReluDense = FlaxT5DenseActDense(self.config, dtype=self.dtype)
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype)
self.dropout = nn.Dropout(self.config.dropout_rate)
@@ -412,10 +409,11 @@ def __call__(
# replace masked positions with -10_000
if attention_mask is not None:
+ mask_value = jnp.finfo(self.dtype).min
attention_mask = jax.lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
- jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+ jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
)
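+ # The mask_value above is the most negative finite value of the computation dtype
+ # (illustrative values: about -3.4e38 for float32, -65504.0 for float16), rather
+ # than the previous hard-coded -1e4.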
if position_bias is None:
@@ -931,18 +929,18 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
- decoder_input_ids = jnp.ones_like(input_ids)
- decoder_attention_mask = jnp.ones_like(input_ids)
+ args = [input_ids, attention_mask]
+ if self.module_class not in [FlaxT5EncoderModule]:
+ decoder_input_ids = jnp.ones_like(input_ids)
+ decoder_attention_mask = jnp.ones_like(input_ids)
+ args.extend([decoder_input_ids, decoder_attention_mask])
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(
rngs,
- input_ids,
- attention_mask,
- decoder_input_ids,
- decoder_attention_mask,
+ *args,
)["params"]
if params is not None:
@@ -977,7 +975,8 @@ def __call__(
if decoder_input_ids is None:
raise ValueError(
- "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
+ "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed"
+ " here."
)
# prepare encoder inputs
@@ -1243,7 +1242,7 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs
@add_start_docstrings(
- "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
+ "The bare T5 Model transformer outputting raw hidden-stateswithout any specific head on top.",
T5_START_DOCSTRING,
)
class FlaxT5Module(nn.Module):
@@ -1358,6 +1357,90 @@ class FlaxT5Model(FlaxT5PreTrainedModel):
append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+@add_start_docstrings(
+ "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
+ T5_START_DOCSTRING,
+)
+class FlaxT5EncoderModule(nn.Module):
+ config: T5Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.shared = nn.Embed(
+ self.config.vocab_size,
+ self.config.d_model,
+ embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+ )
+
+ encoder_config = copy.deepcopy(self.config)
+ encoder_config.is_decoder = False
+ encoder_config.is_encoder_decoder = False
+ encoder_config.causal = False
+ self.encoder = FlaxT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ deterministic: bool = True,
+ ):
+
+ # Encode if needed (training, first prediction pass)
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ return encoder_outputs
+
+
+class FlaxT5EncoderModel(FlaxT5PreTrainedModel):
+ module_class = FlaxT5EncoderModule
+
+ @add_start_docstrings_to_model_forward(T5_ENCODE_INPUTS_DOCSTRING)
+ def __call__(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ # prepare encoder inputs
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+
+ # Handle any PRNG if needed
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+ return self.module.apply(
+ {"params": params or self.params},
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ )
+
+
@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class FlaxT5ForConditionalGenerationModule(nn.Module):
config: T5Config
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 630e9dd17aa5..e4c36109bd77 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -34,7 +34,7 @@
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,
@@ -275,34 +275,36 @@ def forward(self, hidden_states):
logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
pass
+ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
-class T5DenseReluDense(nn.Module):
+
+class T5DenseActDense(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
- self.relu_act = ACT2FN["relu"]
+ self.act = ACT2FN[config.dense_act_fn]
def forward(self, hidden_states):
hidden_states = self.wi(hidden_states)
- hidden_states = self.relu_act(hidden_states)
+ hidden_states = self.act(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states
-class T5DenseGatedGeluDense(nn.Module):
+class T5DenseGatedActDense(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
- self.gelu_act = ACT2FN["gelu_new"]
+ self.act = ACT2FN[config.dense_act_fn]
def forward(self, hidden_states):
- hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
+ hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
@@ -313,14 +315,10 @@ def forward(self, hidden_states):
class T5LayerFF(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
- if config.feed_forward_proj == "relu":
- self.DenseReluDense = T5DenseReluDense(config)
- elif config.feed_forward_proj == "gated-gelu":
- self.DenseReluDense = T5DenseGatedGeluDense(config)
+ if config.is_gated_act:
+ self.DenseReluDense = T5DenseGatedActDense(config)
else:
- raise ValueError(
- f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
- )
+ self.DenseReluDense = T5DenseActDense(config)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
@@ -408,26 +406,24 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
- relative_postion_if_large = max_exact + (
+ relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
- relative_postion_if_large = torch.min(
- relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
+ relative_position_if_large = torch.min(
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
- relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
- def compute_bias(self, query_length, key_length):
+ def compute_bias(self, query_length, key_length, device=None):
"""Compute binned relative position bias"""
- context_position = torch.arange(
- query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
- )[:, None]
- memory_position = torch.arange(
- key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
- )[None, :]
+ if device is None:
+ device = self.relative_attention_bias.weight.device
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
@@ -522,7 +518,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True
else:
- position_bias = self.compute_bias(real_seq_length, key_length)
+ position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
# if key and values are already calculated
# we want only the last query position bias
@@ -747,6 +743,7 @@ class T5PreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
+ _no_split_modules = ["T5Block"]
@property
def dummy_inputs(self):
@@ -768,7 +765,9 @@ def _init_weights(self, module):
# Mesh TensorFlow embeddings initialization
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
- elif isinstance(module, T5DenseReluDense):
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
+ elif isinstance(module, T5DenseActDense):
# Mesh TensorFlow FF initialization
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
# and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
@@ -778,7 +777,7 @@ def _init_weights(self, module):
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
if hasattr(module.wo, "bias") and module.wo.bias is not None:
module.wo.bias.data.zero_()
- elif isinstance(module, T5DenseGatedGeluDense):
+ elif isinstance(module, T5DenseGatedActDense):
module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
module.wi_0.bias.data.zero_()
@@ -809,9 +808,10 @@ def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
- assert (
- decoder_start_token_id is not None
- ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
+ assert decoder_start_token_id is not None, (
+ "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
+ " See T5 docs for more information"
+ )
# shift inputs to the right
if is_torch_fx_proxy(input_ids):
@@ -827,8 +827,6 @@ def _shift_right(self, input_ids):
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
- assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
-
return shifted_input_ids
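A minimal sketch of the shift-right behaviour implemented above, with illustrative tensors that are not part of the patch: the sequence is shifted one position to the right, the decoder start token is prepended, and any `-100` label placeholders are mapped back to the pad id.

```python
import torch

def shift_right(input_ids, decoder_start_token_id, pad_token_id):
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[..., 1:] = input_ids[..., :-1].clone()
    shifted[..., 0] = decoder_start_token_id
    # labels may use -100 for ignored positions; map them back to the pad id
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

labels = torch.tensor([[42, 17, 5, -100]])
print(shift_right(labels, decoder_start_token_id=0, pad_token_id=0))
# tensor([[ 0, 42, 17,  5]])
```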
@@ -944,7 +942,7 @@ def forward(
assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
if attention_mask is None:
- attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
encoder_seq_length = encoder_hidden_states.shape[1]
encoder_attention_mask = torch.ones(
@@ -1268,11 +1266,11 @@ def custom_forward(*inputs):
)
class T5Model(T5PreTrainedModel):
_keys_to_ignore_on_load_missing = [
- r"encoder\.embed_tokens\.weight",
- r"decoder\.embed_tokens\.weight",
+ r"encoder.embed_tokens.weight",
+ r"decoder.embed_tokens.weight",
]
_keys_to_ignore_on_load_unexpected = [
- r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
def __init__(self, config: T5Config):
@@ -1410,8 +1408,7 @@ def forward(
)
hidden_states = encoder_outputs[0]
- if self.model_parallel:
- torch.cuda.set_device(self.decoder.first_device)
+
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)
@@ -1457,12 +1454,12 @@ def forward(
@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
_keys_to_ignore_on_load_missing = [
- r"encoder\.embed_tokens\.weight",
- r"decoder\.embed_tokens\.weight",
- r"lm_head\.weight",
+ r"encoder.embed_tokens.weight",
+ r"decoder.embed_tokens.weight",
+ r"lm_head.weight",
]
_keys_to_ignore_on_load_unexpected = [
- r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
def __init__(self, config: T5Config):
@@ -1751,7 +1748,7 @@ def _reorder_cache(self, past, beam_idx):
)
class T5EncoderModel(T5PreTrainedModel):
authorized_missing_keys = [
- r"encoder\.embed_tokens\.weight",
+ r"encoder.embed_tokens.weight",
]
def __init__(self, config: T5Config):
diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py
index 3434a6ea4f37..2eebdfd1cb60 100644
--- a/src/transformers/models/t5/modeling_tf_t5.py
+++ b/src/transformers/models/t5/modeling_tf_t5.py
@@ -23,7 +23,7 @@
import numpy as np
import tensorflow as tf
-from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
+from tensorflow.compiler.tf2xla.python.xla import dynamic_slice
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
@@ -93,7 +93,7 @@ def call(self, hidden_states):
return self.weight * hidden_states
-class TFT5DenseReluDense(tf.keras.layers.Layer):
+class TFT5DenseActDense(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
wi_initializer = tf.keras.initializers.RandomNormal(
@@ -109,7 +109,7 @@ def __init__(self, config, **kwargs):
config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
) # Update init weights as in flax
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
- self.act = tf.keras.activations.relu
+ self.act = get_tf_activation(config.dense_act_fn)
def call(self, hidden_states, training=False):
hidden_states = self.wi(hidden_states)
@@ -119,7 +119,7 @@ def call(self, hidden_states, training=False):
return hidden_states
-class TFT5GatedGeluDense(tf.keras.layers.Layer):
+class TFT5DenseGatedActDense(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
wi_initializer = tf.keras.initializers.RandomNormal(
@@ -138,7 +138,7 @@ def __init__(self, config, **kwargs):
config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
) # Update init weights as in flax
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
- self.act = get_tf_activation("gelu_new")
+ self.act = get_tf_activation(config.dense_act_fn)
def call(self, hidden_states, training=False):
hidden_gelu = self.act(self.wi_0(hidden_states))
@@ -152,14 +152,11 @@ def call(self, hidden_states, training=False):
class TFT5LayerFF(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
- if config.feed_forward_proj == "relu":
- self.DenseReluDense = TFT5DenseReluDense(config, name="DenseReluDense")
- elif config.feed_forward_proj == "gated-gelu":
- self.DenseReluDense = TFT5GatedGeluDense(config, name="DenseReluDense")
+ if config.is_gated_act:
+ self.DenseReluDense = TFT5DenseGatedActDense(config, name="DenseReluDense")
else:
- raise ValueError(
- f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
- )
+ self.DenseReluDense = TFT5DenseActDense(config, name="DenseReluDense")
+
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
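The dispatch above relies on `is_gated_act` and `dense_act_fn` being derived from `feed_forward_proj` in `T5Config`, which is not shown in this hunk. A short sketch, assuming the standard parsing:

```python
from transformers import T5Config

# "gated-gelu" is expected to map to is_gated_act=True and dense_act_fn="gelu_new",
# while plain "relu" maps to is_gated_act=False and dense_act_fn="relu".
config = T5Config(feed_forward_proj="gated-gelu")
print(config.is_gated_act, config.dense_act_fn)
```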
@@ -271,7 +268,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
max_exact = num_buckets // 2
is_small = tf.math.less(relative_position, max_exact)
relative_position_if_large = max_exact + tf.cast(
- tf.math.log(relative_position / max_exact)
+ tf.math.log(tf.cast(relative_position, tf.float32) / tf.cast(max_exact, tf.float32))
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact),
dtype=relative_position.dtype,
@@ -388,10 +385,19 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
else:
position_bias = self.compute_bias(real_seq_length, key_length)
- # if key and values are already calculated
- # we want only the last query position bias
+ # if key and values are already calculated we want only the last query position bias
if past_key_value is not None:
- position_bias = position_bias[:, :, -seq_length:, :]
+ if not self.has_relative_attention_bias:
+ position_bias = position_bias[:, :, -seq_length:, :]
+ else:
+ # we might have a padded past structure, in which case we want to fetch the position bias slice
+ # right after the most recently filled past index
+ most_recently_filled_past_index = tf.reduce_max(tf.where(past_key_value[0][0, 0, :, 0] != 0.0))
+ position_bias = dynamic_slice(
+ position_bias,
+ (0, 0, most_recently_filled_past_index + 1, 0),
+ (1, self.n_heads, seq_length, real_seq_length),
+ )
if mask is not None:
position_bias = tf.cast(position_bias, dtype=mask.dtype)
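Illustrative only: the XLA `dynamic_slice` call above behaves like `tf.slice` with a dynamic start index, picking the query rows that follow the most recently filled past position. The shapes below are made up.

```python
import tensorflow as tf

# Toy bias table of shape (1, n_heads=8, key_length=16, key_length=16),
# one new query token, past filled up to index 4.
position_bias = tf.random.normal((1, 8, 16, 16))
start = 4 + 1  # position right after the most recently filled past index
sliced = tf.slice(position_bias, begin=[0, 0, start, 0], size=[1, 8, 1, 16])
print(sliced.shape)  # (1, 8, 1, 16)
```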
@@ -406,7 +412,10 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.n_heads],
- message=f"Head mask for a single layer should be of size {(self.n_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.n_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights
@@ -899,9 +908,10 @@ def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
- assert (
- decoder_start_token_id is not None
- ), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
+ assert decoder_start_token_id is not None, (
+ "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the"
+ " pad_token_id. See T5 docs for more information"
+ )
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation
@@ -1102,13 +1112,15 @@ def _shift_right(self, input_ids):
@add_start_docstrings(
- "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
+ "The bare T5 Model transformer outputting raw hidden-stateswithout any specific head on top.",
T5_START_DOCSTRING,
)
class TFT5Model(TFT5PreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
- self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
+ self.shared = TFSharedEmbeddings(
+ config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
+ )
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
@@ -1255,8 +1267,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model_dim = config.d_model
-
- self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
+ self.shared = TFSharedEmbeddings(
+ config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
+ )
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
@@ -1497,70 +1510,6 @@ def prepare_inputs_for_generation(
"use_cache": use_cache,
}
- def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
- # TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
- # quite some duplicated code patterns it seems
- # also the `attention_mask` is currently used in a somewhat hacky to
- # correctly influence the `past_key_values` - not sure if this is the way to go
- # Let's keep that for a future PR.
- past = outputs.past_key_values
- is_past_initialized = model_kwargs.pop("past", None) is not None
- decoder_attention_mask = model_kwargs.pop("decoder_attention_mask", None)
- batch_size = past[0][0].shape[0]
-
- if not is_past_initialized:
- # past[0].shape[3] is seq_length of prompt
- num_padding_values = max_length - past[0][0].shape[2] - 1
-
- padding_values = np.zeros((4, 2), dtype=np.int32)
- padding_values[2, 1] = num_padding_values
- padding_values = tf.constant(padding_values)
-
- new_past = ()
- for past_layer in past:
- new_past_layer = list(past_layer)
- for i in range(len(new_past_layer[:2])):
- new_past_layer[i] = tf.pad(past_layer[i], padding_values)
- new_past += (tuple(new_past_layer),)
-
- # 1 one for decoder_start_token_id, Zeros for the currently-unfilled locations in the past tensor, ones for the actual input_ids
- decoder_attention_mask = tf.concat(
- [
- tf.ones((batch_size, 1), dtype=tf.int32),
- tf.zeros((batch_size, num_padding_values), dtype=tf.int32),
- tf.ones((batch_size, 1), dtype=tf.int32),
- ],
- axis=1,
- )
- else:
- slice_start_base = tf.constant([0, 0, 1, 0])
- decoder_attention_mask_update_slice = tf.ones((batch_size, 1), dtype=decoder_attention_mask.dtype)
- # correct 5 here
- new_past_index = current_pos - 1
-
- new_past = ()
- for past_layer in past:
- new_past_layer = list(past_layer)
- for i in range(len(new_past_layer[:2])):
- update_slice = past_layer[i][:, :, -1:]
- # Write the last slice to the first open location in the padded past array
- # and then truncate the last slice off the array
- new_past_layer[i] = dynamic_update_slice(
- past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
- )
- new_past += (tuple(new_past_layer),)
-
- update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
- decoder_attention_mask = dynamic_update_slice(
- decoder_attention_mask, decoder_attention_mask_update_slice, update_start
- )
-
- # set `attention_mask` and `past`
- model_kwargs["decoder_attention_mask"] = decoder_attention_mask
- model_kwargs["past"] = new_past
-
- return model_kwargs
-
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return self._shift_right(labels)
@@ -1590,13 +1539,15 @@ def _reorder_cache(self, past, beam_idx):
@add_start_docstrings(
- "The bare T5 Model transformer outputting encoder's raw hidden-states" "without any specific head on top.",
+ "The bare T5 Model transformer outputting encoder's raw hidden-stateswithout any specific head on top.",
T5_START_DOCSTRING,
)
class TFT5EncoderModel(TFT5PreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
- self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
+ self.shared = TFSharedEmbeddings(
+ config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
+ )
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
diff --git a/src/transformers/models/t5/tokenization_t5.py b/src/transformers/models/t5/tokenization_t5.py
index 09414ae40772..2dbc788374dc 100644
--- a/src/transformers/models/t5/tokenization_t5.py
+++ b/src/transformers/models/t5/tokenization_t5.py
@@ -131,8 +131,9 @@ def __init__(
extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
if extra_tokens != extra_ids:
raise ValueError(
- f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
- "In this case the additional_special_tokens must include the extra_ids tokens"
+ f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
+ " provided to T5Tokenizer. In this case the additional_special_tokens must include the extra_ids"
+ " tokens"
)
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
@@ -161,11 +162,15 @@ def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_l
return init_max_model_length
elif init_max_model_length is None:
warnings.warn(
- f"This tokenizer was incorrectly instantiated with a model max length of {deprecated_max_model_length} which will be corrected in Transformers v5.\n"
- f"For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n"
- f"- Be aware that you SHOULD NOT rely on {pretrained_model_name_or_path} automatically truncating your input to {deprecated_max_model_length} when padding/encoding.\n"
- f"- If you want to encode/pad to sequences longer than {deprecated_max_model_length} you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n"
- f"- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.",
+ "This tokenizer was incorrectly instantiated with a model max length of"
+ f" {deprecated_max_model_length} which will be corrected in Transformers v5.\nFor now, this"
+ " behavior is kept to avoid breaking backwards compatibility when padding/encoding with"
+ " `truncation is True`.\n- Be aware that you SHOULD NOT rely on"
+ f" {pretrained_model_name_or_path} automatically truncating your input to"
+ f" {deprecated_max_model_length} when padding/encoding.\n- If you want to encode/pad to sequences"
+ f" longer than {deprecated_max_model_length} you can either instantiate this tokenizer with"
+ " `model_max_length` or pass `max_length` when encoding/padding.\n- To avoid this warning, please"
+ " instantiate this tokenizer with `model_max_length` set to your preferred value.",
FutureWarning,
)
@@ -212,7 +217,8 @@ def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
"""Do not add eos again if user already added it."""
if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
warnings.warn(
- f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated eos tokens being added."
+ f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
+ " eos tokens being added."
)
return token_ids
else:
diff --git a/src/transformers/models/t5/tokenization_t5_fast.py b/src/transformers/models/t5/tokenization_t5_fast.py
index 77a86810b3f7..41ad306b74e6 100644
--- a/src/transformers/models/t5/tokenization_t5_fast.py
+++ b/src/transformers/models/t5/tokenization_t5_fast.py
@@ -126,8 +126,9 @@ def __init__(
extra_tokens = len(set(filter(lambda x: bool("extra_id_" in str(x)), additional_special_tokens)))
if extra_tokens != extra_ids:
raise ValueError(
- f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
- "In this case the additional_special_tokens must include the extra_ids tokens"
+ f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
+ " provided to T5Tokenizer. In this case the additional_special_tokens must include the extra_ids"
+ " tokens"
)
super().__init__(
@@ -153,11 +154,15 @@ def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_l
return init_max_model_length
elif init_max_model_length is None:
warnings.warn(
- f"This tokenizer was incorrectly instantiated with a model max length of {deprecated_max_model_length} which will be corrected in Transformers v5.\n"
- f"For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n"
- f"- Be aware that you SHOULD NOT rely on {pretrained_model_name_or_path} automatically truncating your input to {deprecated_max_model_length} when padding/encoding.\n"
- f"- If you want to encode/pad to sequences longer than {deprecated_max_model_length} you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n"
- f"- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.",
+ "This tokenizer was incorrectly instantiated with a model max length of"
+ f" {deprecated_max_model_length} which will be corrected in Transformers v5.\nFor now, this"
+ " behavior is kept to avoid breaking backwards compatibility when padding/encoding with"
+ " `truncation is True`.\n- Be aware that you SHOULD NOT rely on"
+ f" {pretrained_model_name_or_path} automatically truncating your input to"
+ f" {deprecated_max_model_length} when padding/encoding.\n- If you want to encode/pad to sequences"
+ f" longer than {deprecated_max_model_length} you can either instantiate this tokenizer with"
+ " `model_max_length` or pass `max_length` when encoding/padding.\n- To avoid this warning, please"
+ " instantiate this tokenizer with `model_max_length` set to your preferred value.",
FutureWarning,
)
diff --git a/src/transformers/models/tapas/__init__.py b/src/transformers/models/tapas/__init__.py
index 4d3c72b85b32..bbfb09ea0fee 100644
--- a/src/transformers/models/tapas/__init__.py
+++ b/src/transformers/models/tapas/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
@@ -26,7 +26,12 @@
"tokenization_tapas": ["TapasTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tapas"] = [
"TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
"TapasForMaskedLM",
@@ -36,7 +41,12 @@
"TapasPreTrainedModel",
"load_tf_weights_in_tapas",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_tapas"] = [
"TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFTapasForMaskedLM",
@@ -51,7 +61,12 @@
from .configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
from .tokenization_tapas import TapasTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tapas import (
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TapasForMaskedLM,
@@ -62,7 +77,12 @@
load_tf_weights_in_tapas,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_tapas import (
TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TFTapasForMaskedLM,
diff --git a/src/transformers/models/tapas/configuration_tapas.py b/src/transformers/models/tapas/configuration_tapas.py
index 58fb0c66b73a..71fd5715ef57 100644
--- a/src/transformers/models/tapas/configuration_tapas.py
+++ b/src/transformers/models/tapas/configuration_tapas.py
@@ -27,10 +27,18 @@
TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "google/tapas-base-finetuned-sqa": "https://huggingface.co/google/tapas-base-finetuned-sqa/resolve/main/config.json",
- "google/tapas-base-finetuned-wtq": "https://huggingface.co/google/tapas-base-finetuned-wtq/resolve/main/config.json",
- "google/tapas-base-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-base-finetuned-wikisql-supervised/resolve/main/config.json",
- "google/tapas-base-finetuned-tabfact": "https://huggingface.co/google/tapas-base-finetuned-tabfact/resolve/main/config.json",
+ "google/tapas-base-finetuned-sqa": (
+ "https://huggingface.co/google/tapas-base-finetuned-sqa/resolve/main/config.json"
+ ),
+ "google/tapas-base-finetuned-wtq": (
+ "https://huggingface.co/google/tapas-base-finetuned-wtq/resolve/main/config.json"
+ ),
+ "google/tapas-base-finetuned-wikisql-supervised": (
+ "https://huggingface.co/google/tapas-base-finetuned-wikisql-supervised/resolve/main/config.json"
+ ),
+ "google/tapas-base-finetuned-tabfact": (
+ "https://huggingface.co/google/tapas-base-finetuned-tabfact/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py
index 88edacacfddc..2772a7f126ef 100644
--- a/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py
@@ -120,8 +120,10 @@ def convert_tf_checkpoint_to_pytorch(
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained TAPAS model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained TAPAS model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py
index b0c3786ca05a..0b65e84ca7ac 100644
--- a/src/transformers/models/tapas/modeling_tapas.py
+++ b/src/transformers/models/tapas/modeling_tapas.py
@@ -582,7 +582,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -1430,7 +1431,8 @@ def forward(
per_example_additional_loss *= large_answer_loss_mask
else:
raise ValueError(
- "You have to specify numeric values and numeric values scale in order to calculate the regression loss"
+ "You have to specify numeric values and numeric values scale in order to calculate the"
+ " regression loss"
)
total_loss += torch.mean(per_example_additional_loss)
diff --git a/src/transformers/models/tapas/modeling_tf_tapas.py b/src/transformers/models/tapas/modeling_tf_tapas.py
index d2da0644627a..93d98914f1f3 100644
--- a/src/transformers/models/tapas/modeling_tf_tapas.py
+++ b/src/transformers/models/tapas/modeling_tf_tapas.py
@@ -519,8 +519,8 @@ def call(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
- "by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -862,6 +862,19 @@ class TFTapasPreTrainedModel(TFPreTrainedModel):
config_class = TapasConfig
base_model_prefix = "tapas"
+ @tf.function(
+ input_signature=[
+ {
+ "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
+ "attention_mask": tf.TensorSpec((None, None), tf.float32, name="attention_mask"),
+ "token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
+ }
+ ]
+ )
+ def serving(self, inputs):
+ output = self.call(inputs)
+ return self.serving_output(output)
+
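A hedged sketch of what the serving signature above is for: exporting the TF model as a SavedModel whose `token_type_ids` input is rank 3. The checkpoint name and output path are illustrative, and `tensorflow_probability` is assumed to be installed.

```python
import tensorflow as tf
from transformers import TFTapasModel

model = TFTapasModel.from_pretrained("google/tapas-base")  # illustrative checkpoint
tf.saved_model.save(model, "tapas_saved_model", signatures=model.serving)
```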
TAPAS_START_DOCSTRING = r"""
@@ -1021,14 +1034,14 @@ def call(
return outputs
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
- hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
- attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+ hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutputWithPooling(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
- hidden_states=hs,
- attentions=attns,
+ hidden_states=hidden_states,
+ attentions=attentions,
)
@@ -1128,10 +1141,10 @@ def call(
)
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
- hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
- attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+ hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
- return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
+ return TFMaskedLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
class TFTapasComputeTokenLogits(tf.keras.layers.Layer):
@@ -1533,7 +1546,8 @@ def call(
per_example_additional_loss *= large_answer_loss_mask
else:
raise ValueError(
- "You have to specify numeric values and numeric values scale in order to calculate the regression loss"
+ "You have to specify numeric values and numeric values scale in order to calculate the"
+ " regression loss"
)
total_loss += tf.reduce_mean(per_example_additional_loss)
@@ -1556,11 +1570,14 @@ def call(
)
def serving_output(self, output: TFTableQuestionAnsweringOutput) -> TFTableQuestionAnsweringOutput:
- hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
- attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+ hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFTableQuestionAnsweringOutput(
- logits=output.logits, logits_aggregation=output.logits_aggregation, hidden_states=hs, attentions=attns
+ logits=output.logits,
+ logits_aggregation=output.logits_aggregation,
+ hidden_states=hidden_states,
+ attentions=attentions,
)
@@ -1666,10 +1683,10 @@ def call(
)
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
- hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
- attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+ hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
- return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
+ return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
""" TAPAS utilities."""
@@ -1723,10 +1740,13 @@ def __init__(self, outer_index, inner_index):
inner_index: IndexMap, must have the same shape as `outer_index`.
"""
if outer_index.batch_dims != inner_index.batch_dims:
- raise ValueError("outer_index.batch_dims and inner_index.batch_dims " "must be the same.")
+ raise ValueError("outer_index.batch_dims and inner_index.batch_dims must be the same.")
super(ProductIndexMap, self).__init__(
- indices=(inner_index.indices + outer_index.indices * inner_index.num_segments),
+ indices=(
+ inner_index.indices
+ + outer_index.indices * tf.cast(inner_index.num_segments, inner_index.indices.dtype)
+ ),
num_segments=inner_index.num_segments * outer_index.num_segments,
batch_dims=inner_index.batch_dims,
)
@@ -1785,7 +1805,7 @@ def flatten(index, name="segmented_flatten"):
for _ in range(index.batch_dims, index.indices.shape.rank):
offset = tf.expand_dims(offset, -1)
- indices = offset + index.indices
+ indices = tf.cast(offset, index.indices.dtype) + index.indices
return IndexMap(indices=tf.reshape(indices, [-1]), num_segments=index.num_segments * batch_size, batch_dims=0)
diff --git a/src/transformers/models/tapas/tokenization_tapas.py b/src/transformers/models/tapas/tokenization_tapas.py
index 27481c35fb14..ddb855642f43 100644
--- a/src/transformers/models/tapas/tokenization_tapas.py
+++ b/src/transformers/models/tapas/tokenization_tapas.py
@@ -50,35 +50,83 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
# large models
- "google/tapas-large-finetuned-sqa": "https://huggingface.co/google/tapas-large-finetuned-sqa/resolve/main/vocab.txt",
- "google/tapas-large-finetuned-wtq": "https://huggingface.co/google/tapas-large-finetuned-wtq/resolve/main/vocab.txt",
- "google/tapas-large-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-large-finetuned-wikisql-supervised/resolve/main/vocab.txt",
- "google/tapas-large-finetuned-tabfact": "https://huggingface.co/google/tapas-large-finetuned-tabfact/resolve/main/vocab.txt",
+ "google/tapas-large-finetuned-sqa": (
+ "https://huggingface.co/google/tapas-large-finetuned-sqa/resolve/main/vocab.txt"
+ ),
+ "google/tapas-large-finetuned-wtq": (
+ "https://huggingface.co/google/tapas-large-finetuned-wtq/resolve/main/vocab.txt"
+ ),
+ "google/tapas-large-finetuned-wikisql-supervised": (
+ "https://huggingface.co/google/tapas-large-finetuned-wikisql-supervised/resolve/main/vocab.txt"
+ ),
+ "google/tapas-large-finetuned-tabfact": (
+ "https://huggingface.co/google/tapas-large-finetuned-tabfact/resolve/main/vocab.txt"
+ ),
# base models
- "google/tapas-base-finetuned-sqa": "https://huggingface.co/google/tapas-base-finetuned-sqa/resolve/main/vocab.txt",
- "google/tapas-base-finetuned-wtq": "https://huggingface.co/google/tapas-base-finetuned-wtq/resolve/main/vocab.txt",
- "google/tapas-base-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-base-finetuned-wikisql-supervised/resolve/main/vocab.txt",
- "google/tapas-base-finetuned-tabfact": "https://huggingface.co/google/tapas-base-finetuned-tabfact/resolve/main/vocab.txt",
+ "google/tapas-base-finetuned-sqa": (
+ "https://huggingface.co/google/tapas-base-finetuned-sqa/resolve/main/vocab.txt"
+ ),
+ "google/tapas-base-finetuned-wtq": (
+ "https://huggingface.co/google/tapas-base-finetuned-wtq/resolve/main/vocab.txt"
+ ),
+ "google/tapas-base-finetuned-wikisql-supervised": (
+ "https://huggingface.co/google/tapas-base-finetuned-wikisql-supervised/resolve/main/vocab.txt"
+ ),
+ "google/tapas-base-finetuned-tabfact": (
+ "https://huggingface.co/google/tapas-base-finetuned-tabfact/resolve/main/vocab.txt"
+ ),
# medium models
- "google/tapas-medium-finetuned-sqa": "https://huggingface.co/google/tapas-medium-finetuned-sqa/resolve/main/vocab.txt",
- "google/tapas-medium-finetuned-wtq": "https://huggingface.co/google/tapas-medium-finetuned-wtq/resolve/main/vocab.txt",
- "google/tapas-medium-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-medium-finetuned-wikisql-supervised/resolve/main/vocab.txt",
- "google/tapas-medium-finetuned-tabfact": "https://huggingface.co/google/tapas-medium-finetuned-tabfact/resolve/main/vocab.txt",
+ "google/tapas-medium-finetuned-sqa": (
+ "https://huggingface.co/google/tapas-medium-finetuned-sqa/resolve/main/vocab.txt"
+ ),
+ "google/tapas-medium-finetuned-wtq": (
+ "https://huggingface.co/google/tapas-medium-finetuned-wtq/resolve/main/vocab.txt"
+ ),
+ "google/tapas-medium-finetuned-wikisql-supervised": (
+ "https://huggingface.co/google/tapas-medium-finetuned-wikisql-supervised/resolve/main/vocab.txt"
+ ),
+ "google/tapas-medium-finetuned-tabfact": (
+ "https://huggingface.co/google/tapas-medium-finetuned-tabfact/resolve/main/vocab.txt"
+ ),
# small models
- "google/tapas-small-finetuned-sqa": "https://huggingface.co/google/tapas-small-finetuned-sqa/resolve/main/vocab.txt",
- "google/tapas-small-finetuned-wtq": "https://huggingface.co/google/tapas-small-finetuned-wtq/resolve/main/vocab.txt",
- "google/tapas-small-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-small-finetuned-wikisql-supervised/resolve/main/vocab.txt",
- "google/tapas-small-finetuned-tabfact": "https://huggingface.co/google/tapas-small-finetuned-tabfact/resolve/main/vocab.txt",
+ "google/tapas-small-finetuned-sqa": (
+ "https://huggingface.co/google/tapas-small-finetuned-sqa/resolve/main/vocab.txt"
+ ),
+ "google/tapas-small-finetuned-wtq": (
+ "https://huggingface.co/google/tapas-small-finetuned-wtq/resolve/main/vocab.txt"
+ ),
+ "google/tapas-small-finetuned-wikisql-supervised": (
+ "https://huggingface.co/google/tapas-small-finetuned-wikisql-supervised/resolve/main/vocab.txt"
+ ),
+ "google/tapas-small-finetuned-tabfact": (
+ "https://huggingface.co/google/tapas-small-finetuned-tabfact/resolve/main/vocab.txt"
+ ),
# tiny models
- "google/tapas-tiny-finetuned-sqa": "https://huggingface.co/google/tapas-tiny-finetuned-sqa/resolve/main/vocab.txt",
- "google/tapas-tiny-finetuned-wtq": "https://huggingface.co/google/tapas-tiny-finetuned-wtq/resolve/main/vocab.txt",
- "google/tapas-tiny-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-tiny-finetuned-wikisql-supervised/resolve/main/vocab.txt",
- "google/tapas-tiny-finetuned-tabfact": "https://huggingface.co/google/tapas-tiny-finetuned-tabfact/resolve/main/vocab.txt",
+ "google/tapas-tiny-finetuned-sqa": (
+ "https://huggingface.co/google/tapas-tiny-finetuned-sqa/resolve/main/vocab.txt"
+ ),
+ "google/tapas-tiny-finetuned-wtq": (
+ "https://huggingface.co/google/tapas-tiny-finetuned-wtq/resolve/main/vocab.txt"
+ ),
+ "google/tapas-tiny-finetuned-wikisql-supervised": (
+ "https://huggingface.co/google/tapas-tiny-finetuned-wikisql-supervised/resolve/main/vocab.txt"
+ ),
+ "google/tapas-tiny-finetuned-tabfact": (
+ "https://huggingface.co/google/tapas-tiny-finetuned-tabfact/resolve/main/vocab.txt"
+ ),
# mini models
- "google/tapas-mini-finetuned-sqa": "https://huggingface.co/google/tapas-mini-finetuned-sqa/resolve/main/vocab.txt",
- "google/tapas-mini-finetuned-wtq": "https://huggingface.co/google/tapas-mini-finetuned-wtq/resolve/main/vocab.txt",
- "google/tapas-mini-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-mini-finetuned-wikisql-supervised/resolve/main/vocab.txt",
- "google/tapas-mini-finetuned-tabfact": "https://huggingface.co/google/tapas-mini-finetuned-tabfact/resolve/main/vocab.txt",
+ "google/tapas-mini-finetuned-sqa": (
+ "https://huggingface.co/google/tapas-mini-finetuned-sqa/resolve/main/vocab.txt"
+ ),
+ "google/tapas-mini-finetuned-wtq": (
+ "https://huggingface.co/google/tapas-mini-finetuned-wtq/resolve/main/vocab.txt"
+ ),
+ "google/tapas-mini-finetuned-wikisql-supervised": (
+ "https://huggingface.co/google/tapas-mini-finetuned-wikisql-supervised/resolve/main/vocab.txt"
+ ),
+ "google/tapas-mini-finetuned-tabfact": (
+ "https://huggingface.co/google/tapas-mini-finetuned-tabfact/resolve/main/vocab.txt"
+ ),
}
}
@@ -329,8 +377,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
@@ -594,7 +642,8 @@ def __call__(
if not valid_query:
raise ValueError(
- "queries input must of type `str` (single example), `List[str]` (batch or single pretokenized example). "
+ "queries input must of type `str` (single example), `List[str]` (batch or single pretokenized"
+ " example). "
)
is_batched = isinstance(queries, (list, tuple))
@@ -1229,7 +1278,7 @@ def prepare_for_model(
if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose:
if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
logger.warning(
- f"Token indices sequence length is longer than the specified maximum sequence length "
+ "Token indices sequence length is longer than the specified maximum sequence length "
f"for this model ({len(encoded_inputs['input_ids'])} > {self.model_max_length}). Running this "
"sequence through the model will result in indexing errors."
)
diff --git a/src/transformers/models/tapex/__init__.py b/src/transformers/models/tapex/__init__.py
index 36c5938d23c9..3b13bed2ca10 100644
--- a/src/transformers/models/tapex/__init__.py
+++ b/src/transformers/models/tapex/__init__.py
@@ -21,9 +21,7 @@
from ...file_utils import _LazyModule
-_import_structure = {
- "tokenization_tapex": ["TapexTokenizer"],
-}
+_import_structure = {"tokenization_tapex": ["TapexTokenizer"]}
if TYPE_CHECKING:
diff --git a/src/transformers/models/tapex/tokenization_tapex.py b/src/transformers/models/tapex/tokenization_tapex.py
index 0b5c1241415a..7c0725ffe7c1 100644
--- a/src/transformers/models/tapex/tokenization_tapex.py
+++ b/src/transformers/models/tapex/tokenization_tapex.py
@@ -17,7 +17,6 @@
import json
import os
import random
-from contextlib import contextmanager
from functools import lru_cache
from typing import Dict, List, Optional, Tuple, Union
@@ -63,12 +62,6 @@ class TapexTruncationStrategy(ExplicitEnum):
DROP_ROWS_TO_FIT = "drop_rows_to_fit"
-class TokenizerStrategy(ExplicitEnum):
-
- TOKENIZE_SOURCE = "tokenize_source"
- TOKENIZE_TARGET = "tokenize_target"
-
-
TAPEX_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
add_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to encode the sequences with the special tokens relative to their model.
@@ -341,9 +334,6 @@ def __init__(
self.max_cell_length = max_cell_length
self.table_linearize = IndexedRowTableLinearize()
- # property to decide using which call function
- self.current_tokenizer = TokenizerStrategy.TOKENIZE_SOURCE
-
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
@@ -503,7 +493,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
@@ -555,9 +545,7 @@ def __call__(
Optionally, the corresponding answer to the questions as supervision.
"""
- if self.current_tokenizer == TokenizerStrategy.TOKENIZE_SOURCE:
- if table is None:
- raise ValueError("Please ensure that the table is not empty if you use TAPEX to encode source.")
+ if table is not None:
return self.source_call_func(
table=table,
query=query,
@@ -578,9 +566,7 @@ def __call__(
verbose=verbose,
**kwargs,
)
- else:
- if answer is None:
- raise ValueError("Please ensure that the answer is not empty if you use TAPEX to encode target.")
+ elif answer is not None:
return self.target_call_func(
answer=answer,
add_special_tokens=add_special_tokens,
@@ -599,6 +585,8 @@ def __call__(
verbose=verbose,
**kwargs,
)
+ else:
+ raise ValueError("You need to provide either a `table` or an `answer`.")
def source_call_func(
self,
@@ -1330,17 +1318,6 @@ def _target_encode_plus(
verbose=verbose,
)
- @contextmanager
- def as_target_tokenizer(self):
- """
- Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
- sequence-to-sequence models that need a slightly different processing for the labels.
- """
- self.current_tokenizer = TokenizerStrategy.TOKENIZE_TARGET
- yield
- # restore the call function
- self.current_tokenizer = TokenizerStrategy.TOKENIZE_SOURCE
-
def prepare_table_query(
self,
table,
diff --git a/src/transformers/models/trajectory_transformer/__init__.py b/src/transformers/models/trajectory_transformer/__init__.py
new file mode 100644
index 000000000000..0b8a6f2c5892
--- /dev/null
+++ b/src/transformers/models/trajectory_transformer/__init__.py
@@ -0,0 +1,68 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+# rely on isort to merge the imports
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {
+ "configuration_trajectory_transformer": [
+ "TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "TrajectoryTransformerConfig",
+ ],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_trajectory_transformer"] = [
+ "TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TrajectoryTransformerModel",
+ "TrajectoryTransformerPreTrainedModel",
+ "load_tf_weights_in_trajectory_transformer",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_trajectory_transformer import (
+ TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ TrajectoryTransformerConfig,
+ )
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_trajectory_transformer import (
+ TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TrajectoryTransformerModel,
+ TrajectoryTransformerPreTrainedModel,
+ load_tf_weights_in_trajectory_transformer,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/trajectory_transformer/configuration_trajectory_transformer.py b/src/transformers/models/trajectory_transformer/configuration_trajectory_transformer.py
new file mode 100644
index 000000000000..537a467c7016
--- /dev/null
+++ b/src/transformers/models/trajectory_transformer/configuration_trajectory_transformer.py
@@ -0,0 +1,167 @@
+# coding=utf-8
+# Copyright 2022 The Trajectory Transformers paper authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TrajectoryTransformer model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "CarlCochet/trajectory-transformer-halfcheetah-medium-v2": (
+ "https://huggingface.co/CarlCochet/trajectory-transformer-halfcheetah-medium-v2/resolve/main/config.json"
+ ),
+ # See all TrajectoryTransformer models at https://huggingface.co/models?filter=trajectory_transformer
+}
+
+
+class TrajectoryTransformerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`TrajectoryTransformerModel`]. It is used to
+ instantiate a TrajectoryTransformer model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the
+ TrajectoryTransformer
+ [CarlCochet/trajectory-transformer-halfcheetah-medium-v2](https://huggingface.co/CarlCochet/trajectory-transformer-halfcheetah-medium-v2)
+ architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 100):
+ Vocabulary size of the TrajectoryTransformer model. Defines the number of different tokens that can be
+ represented by the `trajectories` passed when calling [`TrajectoryTransformerModel`]
+ batch_size (`int`, *optional*, defaults to 256):
+ Size of the batch of trajectories passed to the model.
+ action_weight (`int`, *optional*, defaults to 5):
+ Weight of the action in the loss function
+ reward_weight (`int`, *optional*, defaults to 1):
+ Weight of the reward in the loss function
+ value_weight (`int`, *optional*, defaults to 1):
+ Weight of the value in the loss function
+ block_size (`int`, *optional*, defaults to 249):
+ Size of the blocks in the trajectory transformer.
+ action_dim (`int`, *optional*, defaults to 6):
+ Dimension of the action space.
+ observation_dim (`int`, *optional*, defaults to 17):
+ Dimension of the observation space.
+ transition_dim (`int`, *optional*, defaults to 25):
+ Dimension of the transition space.
+ n_layer (`int`, *optional*, defaults to 4):
+ Number of hidden layers in the Transformer encoder.
+ n_head (`int`, *optional*, defaults to 4):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ n_embd (`int`, *optional*, defaults to 128):
+ Dimensionality of the embeddings and hidden states.
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ embd_pdrop (`int`, *optional*, defaults to 0.1):
+ The dropout ratio for the embeddings.
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed when calling [`TrajectoryTransformerModel`]
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ kaiming_initializer_range (`float`, *optional*, defaults to 1):
+ A coefficient scaling the negative slope of the kaiming initializer rectifier for EinLinear layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ Example:
+
+ ```python
+ >>> from transformers import TrajectoryTransformerModel, TrajectoryTransformerConfig
+
+ >>> # Initializing a TrajectoryTransformer CarlCochet/trajectory-transformer-halfcheetah-medium-v2 style configuration
+ >>> configuration = TrajectoryTransformerConfig()
+
+ >>> # Initializing a model from the CarlCochet/trajectory-transformer-halfcheetah-medium-v2 style configuration
+ >>> model = TrajectoryTransformerModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "trajectory_transformer"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {
+ "hidden_size": "n_embd",
+ "num_attention_heads": "n_head",
+ "num_hidden_layers": "n_layer",
+ }
+
+ def __init__(
+ self,
+ vocab_size=100,
+ batch_size=256,
+ action_weight=5,
+ reward_weight=1,
+ value_weight=1,
+ block_size=249,
+ action_dim=6,
+ observation_dim=17,
+ transition_dim=25,
+ n_layer=4,
+ n_head=4,
+ n_embd=128,
+ embd_pdrop=0.1,
+ attn_pdrop=0.1,
+ resid_pdrop=0.1,
+ learning_rate=0.0006,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ kaiming_initializer_range=1,
+ use_cache=True,
+ is_encoder_decoder=False,
+ pad_token_id=1,
+ bos_token_id=50256,
+ eos_token_id=50256,
+ **kwargs
+ ):
+ self.vocab_size = vocab_size
+ self.batch_size = batch_size
+ self.action_weight = action_weight
+ self.reward_weight = reward_weight
+ self.value_weight = value_weight
+ self.max_position_embeddings = max_position_embeddings
+ self.block_size = block_size
+ self.action_dim = action_dim
+ self.observation_dim = observation_dim
+ self.transition_dim = transition_dim
+ self.learning_rate = learning_rate
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.n_embd = n_embd
+ self.embd_pdrop = embd_pdrop
+ self.attn_pdrop = attn_pdrop
+ self.resid_pdrop = resid_pdrop
+ self.initializer_range = initializer_range
+ self.type_vocab_size = type_vocab_size
+ self.layer_norm_eps = layer_norm_eps
+ self.kaiming_initializer_range = kaiming_initializer_range
+ self.use_cache = use_cache
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
diff --git a/src/transformers/models/trajectory_transformer/convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/trajectory_transformer/convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py
new file mode 100644
index 000000000000..14e6556e07b7
--- /dev/null
+++ b/src/transformers/models/trajectory_transformer/convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py
@@ -0,0 +1,70 @@
+# coding=utf-8
+# Copyright 2022 The Trajectory Transformers paper authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TrajectoryTransformer pytorch checkpoint conversion"""
+
+import torch
+
+import trajectory.utils as utils
+from transformers import TrajectoryTransformerModel
+
+
+class Parser(utils.Parser):
+ dataset: str = "halfcheetah-medium-expert-v2"
+ config: str = "config.offline"
+
+
+def convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch(logbase, dataset, loadpath, epoch, device):
+ """Converting Sequential blocks to ModuleList"""
+
+ gpt, gpt_epoch = utils.load_model(logbase, dataset, loadpath, epoch=epoch, device=device)
+ trajectory_transformer = TrajectoryTransformerModel(gpt.config)
+
+ trajectory_transformer.tok_emb.load_state_dict(gpt.tok_emb.state_dict())
+ trajectory_transformer.pos_emb = gpt.pos_emb
+ trajectory_transformer.drop.load_state_dict(gpt.drop.state_dict())
+ trajectory_transformer.ln_f.load_state_dict(gpt.ln_f.state_dict())
+ trajectory_transformer.head.load_state_dict(gpt.head.state_dict())
+
+ for i, block in enumerate(gpt.blocks):
+ trajectory_transformer.blocks[i].ln1.load_state_dict(gpt.blocks[i].ln1.state_dict())
+ trajectory_transformer.blocks[i].ln2.load_state_dict(gpt.blocks[i].ln2.state_dict())
+ trajectory_transformer.blocks[i].attn.load_state_dict(gpt.blocks[i].attn.state_dict())
+
+ trajectory_transformer.blocks[i].l1.load_state_dict(gpt.blocks[i].mlp[0].state_dict())
+ trajectory_transformer.blocks[i].act.load_state_dict(gpt.blocks[i].mlp[1].state_dict())
+ trajectory_transformer.blocks[i].l2.load_state_dict(gpt.blocks[i].mlp[2].state_dict())
+ trajectory_transformer.blocks[i].drop.load_state_dict(gpt.blocks[i].mlp[3].state_dict())
+
+ torch.save(trajectory_transformer.state_dict(), "pytorch_model.bin")
+
+
+if __name__ == "__main__":
+ """
+ To run this script you will need to install the original repository to run the original model. You can find it
+ here: https://github.com/jannerm/trajectory-transformer. From this repository you can also download the
+ original pytorch checkpoints.
+
+ Run with the command:
+
+ ```sh
+ >>> python convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py --dataset
+ ... --gpt_loadpath
+ ```
+ """
+
+ args = Parser().parse_args("plan")
+ convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch(
+ args.logbase, args.dataset, args.gpt_loadpath, args.gpt_epoch, args.device
+ )
diff --git a/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py
new file mode 100644
index 000000000000..b2c14029a074
--- /dev/null
+++ b/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py
@@ -0,0 +1,617 @@
+# coding=utf-8
+# Copyright 2022 The Trajectory Transformers paper authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch TrajectoryTransformer model."""
+
+import math
+import os
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import functional as F
+
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_trajectory_transformer import TrajectoryTransformerConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "CarlCochet/trajectory-transformer-halfcheetah-medium-v2"
+_CONFIG_FOR_DOC = "TrajectoryTransformerConfig"
+
+TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "CarlCochet/trajectory-transformer-halfcheetah-medium-v2",
+ # See all TrajectoryTransformer models at https://huggingface.co/models?filter=trajectory_transformer
+]
+
+
+def load_tf_weights_in_trajectory_transformer(model, config, tf_checkpoint_path):
+ """Load tf checkpoints in a pytorch model."""
+ try:
+ import re
+
+ import numpy as np
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(tf_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array)
+
+ for name, array in zip(names, arrays):
+ name = name.split("/")
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
+ # which are not required for using pretrained model
+ if any(
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+ for n in name
+ ):
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+ scope_names = re.split(r"_(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "output_weights":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "squad":
+ pointer = getattr(pointer, "classifier")
+ else:
+ try:
+ pointer = getattr(pointer, scope_names[0])
+ except AttributeError:
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+ if m_name[-11:] == "_embeddings":
+ pointer = getattr(pointer, "weight")
+ elif m_name == "kernel":
+ array = np.transpose(array)
+ try:
+ if pointer.shape != array.shape:
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+ except ValueError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ logger.info(f"Initialize PyTorch weight {name}")
+ pointer.data = torch.from_numpy(array)
+ return model
+
+
+@dataclass
+class TrajectoryTransformerOutput(ModelOutput):
+ """
+ Base class for the outputs of the TrajectoryTransformer model.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
+ sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the
+ attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
+ plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average
+ in the self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class TrajectoryTransformerPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = TrajectoryTransformerConfig
+ load_tf_weights = load_tf_weights_in_trajectory_transformer
+ base_model_prefix = "trajectory_transformer"
+ main_input_name = "trajectories"
+ supports_gradient_checkpointing = True
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, TrajectoryTransformerModel):
+ module.gradient_checkpointing = value
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, EinLinear):
+ for i in range(module.n_models):
+ nn.init.kaiming_uniform_(module.weight[i], a=math.sqrt(5) / self.config.kaiming_initializer_range)
+ if module.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight[i])
+ bound = (1 / math.sqrt(fan_in)) * self.config.initializer_range
+ nn.init.uniform_(module.bias[i], -bound, bound)
+
+
+TRAJECTORY_TRANSFORMER_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`TrajectoryTransformerConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+TRAJECTORY_TRANSFORMER_INPUTS_DOCSTRING = r"""
+ Args:
+ trajectories (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Batch of trajectories, where a trajectory is a sequence of states, actions and rewards.
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`, *optional*):
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `trajectories` whose past
+ is given to this model should not be passed again, as they have already been computed.
+ targets (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Desired targets used to compute the loss.
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class EinLinear(nn.Module):
+ def __init__(self, n_models, in_features, out_features, bias):
+ super().__init__()
+ self.n_models = n_models
+ self.out_features = out_features
+ self.in_features = in_features
+ self.weight = nn.Parameter(torch.Tensor(n_models, out_features, in_features))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(n_models, out_features))
+ else:
+ self.register_parameter("bias", None)
+
+ def reset_parameters(self):
+ for i in range(self.n_models):
+ nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i])
+ bound = 1 / math.sqrt(fan_in)
+ nn.init.uniform_(self.bias[i], -bound, bound)
+
+ def forward(self, input):
+ """
+ Args:
+ input (`torch.FloatTensor` of shape `(B, n_models, input_dim)`):
+ The input to the layer.
+ """
+ # [ batch_size x n_models x output_dim ]
+ output = torch.einsum("eoi,bei->beo", self.weight, input)
+ if self.bias is not None:
+ raise RuntimeError("Adding the bias in EinLinear.forward is not implemented")
+ return output
+
+
+class CausalSelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ if config.n_embd % config.n_head != 0:
+ raise ValueError(f"n_head ({config.n_head}) should be a divisor of n_embd ({config.n_embd})")
+
+ # key, query, value projections for all heads
+ self.key = nn.Linear(config.n_embd, config.n_embd)
+ self.query = nn.Linear(config.n_embd, config.n_embd)
+ self.value = nn.Linear(config.n_embd, config.n_embd)
+
+ # regularization
+ self.attn_drop = nn.Dropout(config.attn_pdrop)
+ self.resid_drop = nn.Dropout(config.resid_pdrop)
+
+ # output projection
+ self.proj = nn.Linear(config.n_embd, config.n_embd)
+
+ # causal mask to ensure that attention is only applied to the left in the input sequence
+ self.register_buffer(
+ "mask",
+ torch.tril(torch.ones(config.block_size, config.block_size)).view(
+ 1, 1, config.block_size, config.block_size
+ ),
+ )
+
+ # mask previous value estimates
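+ # joined_dim = observation_dim + action_dim + 2: one slot per observation and action dimension plus one
+ # reward token and one value token per transition; zeroing every (joined_dim - 1)-th column hides earlier
+ # value estimates from later positions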
+ joined_dim = config.observation_dim + config.action_dim + 2
+ self.mask.squeeze()[:, joined_dim - 1 :: joined_dim] = 0
+
+ self.n_head = config.n_head
+
+ def forward(
+ self,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ):
+ batch_size, sequence_length, embedding_dim = hidden_states.size()
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ # [ batch_size x n_heads x sequence_length x head_dim ]
+ key = (
+ self.key(hidden_states)
+ .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head)
+ .transpose(1, 2)
+ )
+ query = (
+ self.query(hidden_states)
+ .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head)
+ .transpose(1, 2)
+ )
+ value = (
+ self.value(hidden_states)
+ .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head)
+ .transpose(1, 2)
+ )
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ key = torch.cat((past_key, key), dim=-2)
+ value = torch.cat((past_value, value), dim=-2)
+
+ if use_cache is True:
+ present = (key, value)
+ else:
+ present = None
+
+ # causal self-attention
+ # [ batch_size x n_heads x sequence_length x sequence_length ]
+ attn_weights = (torch.matmul(query, key.transpose(-2, -1))) * (1.0 / math.sqrt(key.size(-1)))
+ attn_weights = attn_weights.masked_fill(
+ self.mask[:, :, :sequence_length, :sequence_length] == 0, torch.finfo(attn_weights.dtype).min
+ )
+ attn_weights = F.softmax(attn_weights, dim=-1)
+ self._attn_map = attn_weights.clone()
+ attn_weights = self.attn_drop(attn_weights)
+
+ output = torch.matmul(attn_weights, value)
+ # [ batch_size x sequence_length x embedding_dim ]
+ # re-assemble all head outputs side by side
+ output = output.transpose(1, 2).contiguous().view(batch_size, sequence_length, embedding_dim)
+
+ # output projection
+ output = self.resid_drop(self.proj(output))
+
+ outputs = (output, present)
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class Block(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.ln1 = nn.LayerNorm(config.n_embd)
+ self.ln2 = nn.LayerNorm(config.n_embd)
+ self.attn = CausalSelfAttention(config)
+
+ # MLP
+ self.l1 = nn.Linear(config.n_embd, 4 * config.n_embd)
+ self.act = nn.GELU()
+ self.l2 = nn.Linear(4 * config.n_embd, config.n_embd)
+ self.drop = nn.Dropout(config.resid_pdrop)
+
+ def forward(
+ self,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ):
+ residual = hidden_states
+ hidden_states = self.ln1(hidden_states)
+
+ attn_outputs = self.attn(
+ hidden_states, layer_past=layer_past, use_cache=use_cache, output_attentions=output_attentions
+ )
+ attn_output = attn_outputs[0]
+ outputs = attn_outputs[1:]
+ hidden_states = attn_output + residual
+
+ residual = hidden_states
+ hidden_states = self.ln2(hidden_states)
+ hidden_states = self.l1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.l2(hidden_states)
+ hidden_states = residual + self.drop(hidden_states)
+
+ if use_cache:
+ outputs = (hidden_states,) + outputs
+ else:
+ outputs = (hidden_states,) + outputs[1:]
+
+ return outputs
+
+
+@add_start_docstrings(
+ "The bare TrajectoryTransformer Model transformer outputting raw hidden-states without any specific head on top.",
+ TRAJECTORY_TRANSFORMER_START_DOCSTRING,
+)
+class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel):
+ """the full GPT language model, with a context size of block_size"""
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ # input embedding stem (+1 for stop token)
+ self.tok_emb = nn.Embedding(config.vocab_size * config.transition_dim + 1, config.n_embd)
+
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+ self.drop = nn.Dropout(config.embd_pdrop)
+ # transformer
+ self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
+ # decoder head
+ self.ln_f = nn.LayerNorm(config.n_embd)
+ self.head = EinLinear(config.transition_dim, config.n_embd, config.vocab_size + 1, bias=False)
+
+ self.vocab_size = config.vocab_size
+ self.stop_token = config.vocab_size * config.transition_dim
+ self.block_size = config.block_size
+
+ self.observation_dim = config.observation_dim
+ self.action_dim = config.action_dim
+ self.transition_dim = config.transition_dim
+ self.embedding_dim = config.n_embd
+
+ self.action_weight = config.action_weight
+ self.reward_weight = config.reward_weight
+ self.value_weight = config.value_weight
+
+ self.gradient_checkpointing = False
+
+ self.post_init()
+
+ def get_block_size(self):
+ return self.block_size
+
+ def offset_tokens(self, trajectories):
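+ # each position inside a transition gets its own slice of the shared embedding table, so identical token
+ # ids coming from different fields (observation, action, reward, value) map to different embeddings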
+ _, sequence_length = trajectories.shape
+
+ n_states = int(np.ceil(sequence_length / self.transition_dim))
+
+ offsets = torch.arange(self.transition_dim) * self.vocab_size
+ offsets = offsets.repeat(n_states).to(trajectories.device)
+
+ offset_trajectories = trajectories + offsets[:sequence_length]
+ offset_trajectories[trajectories == self.vocab_size] = self.stop_token
+ return offset_trajectories
+
+ def pad_to_full_observation(self, hidden_states):
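+ # pad the sequence to a multiple of transition_dim and fold it into (batch * n_transitions, transition_dim,
+ # embedding_dim) so that the EinLinear head can score each slot of a transition with its own weights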
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ n_pad = (self.transition_dim - sequence_length % self.transition_dim) % self.transition_dim
+ padding = torch.zeros(batch_size, n_pad, self.embedding_dim, device=hidden_states.device)
+
+ # [ batch_size x padded_sequence_length' x embedding_dim ]
+ hidden_states_pad = torch.cat([hidden_states, padding], dim=1)
+ hidden_states_pad = hidden_states_pad.view(-1, self.transition_dim, self.embedding_dim)
+
+ return hidden_states_pad, n_pad
+
+ @add_start_docstrings_to_model_forward(
+ TRAJECTORY_TRANSFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+ )
+ @replace_return_docstrings(output_type=TrajectoryTransformerOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ trajectories: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ targets: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import TrajectoryTransformerModel
+ >>> import numpy as np
+ >>> import torch
+
+ >>> model = TrajectoryTransformerModel.from_pretrained(
+ ... "CarlCochet/trajectory-transformer-halfcheetah-medium-v2"
+ ... )
+ >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ >>> model.to(device)
+ >>> model.eval()
+
+ >>> observations_dim, action_dim, batch_size = 17, 6, 256
+ >>> seq_length = observations_dim + action_dim + 1
+
+ >>> trajectories = torch.LongTensor([np.random.permutation(seq_length) for _ in range(batch_size)]).to(device)
+ >>> targets = torch.LongTensor([np.random.permutation(seq_length) for _ in range(batch_size)]).to(device)
+
+ >>> outputs = model(
+ ... trajectories,
+ ... targets=targets,
+ ... use_cache=True,
+ ... output_attentions=True,
+ ... output_hidden_states=True,
+ ... return_dict=True,
+ ... )
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if past_key_values is None:
+ past_key_values = tuple([None] * len(self.blocks))
+
+ batch_size, sequence_length = trajectories.size()
+
+ if sequence_length > self.block_size:
+ raise ValueError("Cannot forward, model block size is exhausted.")
+
+ offset_trajectories = self.offset_tokens(trajectories)
+ # [ batch_size x sequence_length x embedding_dim ]
+ # forward the GPT model
+ token_embeddings = self.tok_emb(offset_trajectories) # each index maps to a (learnable) vector
+ position_embeddings = self.pos_emb[:, :sequence_length, :] # each position maps to a (learnable) vector
+
+ hidden_states = self.drop(token_embeddings + position_embeddings)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ for i, (block, layer_past) in enumerate(zip(self.blocks, past_key_values)):
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ layer_past,
+ use_cache,
+ output_attentions,
+ )
+ else:
+ outputs = block(hidden_states, layer_past, use_cache, output_attentions)
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+ # [ batch_size x sequence_length x embedding_dim ]
+ hidden_state = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states_pad, n_pad = self.pad_to_full_observation(hidden_state)
+
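+ # the head returns (batch * n_transitions, transition_dim, vocab_size + 1); reshape back to the original
+ # sequence layout and drop the padded positions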
+ logits = self.head(hidden_states_pad)
+ logits = logits.reshape(batch_size, sequence_length + n_pad, self.vocab_size + 1)
+ logits = logits[:, :sequence_length]
+
+ # if we are given some desired targets also calculate the loss
+ if targets is not None:
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1), reduction="none")
+ if self.action_weight != 1 or self.reward_weight != 1 or self.value_weight != 1:
+ # make weights
+ n_states = int(np.ceil(sequence_length / self.transition_dim))
+ weights = torch.cat(
+ [
+ torch.ones(self.observation_dim, device=trajectories.device),
+ torch.ones(self.action_dim, device=trajectories.device) * self.action_weight,
+ torch.ones(1, device=trajectories.device) * self.reward_weight,
+ torch.ones(1, device=trajectories.device) * self.value_weight,
+ ]
+ )
+ weights = weights.repeat(n_states)
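+ # drop the first weight so the per-token weights line up with next-token targets (in the original
+ # repository the targets are the input trajectory shifted left by one token)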
+ weights = weights[1:].repeat(batch_size, 1)
+ loss = loss * weights.view(-1)
+ loss = (loss * attention_mask.view(-1)).mean()
+ else:
+ loss = None
+
+ if not return_dict:
+ return tuple(v for v in [loss, logits, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+ return TrajectoryTransformerOutput(
+ loss=loss,
+ logits=logits,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
diff --git a/src/transformers/models/transfo_xl/__init__.py b/src/transformers/models/transfo_xl/__init__.py
index ed01124a4905..672ad9afc527 100644
--- a/src/transformers/models/transfo_xl/__init__.py
+++ b/src/transformers/models/transfo_xl/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
@@ -26,7 +26,12 @@
"tokenization_transfo_xl": ["TransfoXLCorpus", "TransfoXLTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_transfo_xl"] = [
"TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST",
"AdaptiveEmbedding",
@@ -37,7 +42,12 @@
"load_tf_weights_in_transfo_xl",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_transfo_xl"] = [
"TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFAdaptiveEmbedding",
@@ -53,7 +63,12 @@
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
AdaptiveEmbedding,
@@ -64,7 +79,12 @@
load_tf_weights_in_transfo_xl,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_transfo_xl import (
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAdaptiveEmbedding,
diff --git a/src/transformers/models/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
index abde04bd43c7..646c8a2342fc 100755
--- a/src/transformers/models/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
@@ -101,8 +101,10 @@ def convert_transfo_xl_checkpoint_to_pytorch(
"--transfo_xl_config_file",
default="",
type=str,
- help="An optional config json file corresponding to the pre-trained BERT model. \n"
- "This specifies the model architecture.",
+ help=(
+ "An optional config json file corresponding to the pre-trained BERT model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--transfo_xl_dataset_file",
diff --git a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
index 29753738839c..66467350f142 100644
--- a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
+++ b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
@@ -935,9 +935,10 @@ def __init__(self, config):
super().__init__(config)
self.transformer = TFTransfoXLMainLayer(config, name="transformer")
self.sample_softmax = config.sample_softmax
- assert (
- self.sample_softmax <= 0
- ), "Sampling from the softmax is not implemented yet. Please look at issue: #3310: https://github.com/huggingface/transformers/issues/3310"
+ assert self.sample_softmax <= 0, (
+ "Sampling from the softmax is not implemented yet. Please look at issue: #3310:"
+ " https://github.com/huggingface/transformers/issues/3310"
+ )
self.crit = TFAdaptiveSoftmaxMask(
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit"
@@ -1126,7 +1127,7 @@ def call(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None
diff --git a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py
index af95f348ec28..dcfa84d0f94b 100644
--- a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py
+++ b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py
@@ -111,7 +111,7 @@ def _logit(x, W, b, proj=None):
@staticmethod
def _gather_logprob(logprob, target):
lp_size = shape_list(logprob)
- r = tf.range(lp_size[0])
+ r = tf.range(lp_size[0], dtype=target.dtype)
idx = tf.stack([r, target], 1)
return tf.gather_nd(logprob, idx)
diff --git a/src/transformers/models/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_transfo_xl.py
index 556525cbf6c8..75793466c7a8 100644
--- a/src/transformers/models/transfo_xl/modeling_transfo_xl.py
+++ b/src/transformers/models/transfo_xl/modeling_transfo_xl.py
@@ -327,21 +327,17 @@ def forward(self, w, r, attn_mask=None, mems=None, head_mask=None, output_attent
attn_score = AC + BD
attn_score.mul_(self.scale)
+ mask_value = torch.finfo(attn_score.dtype).min
+
# compute attention probability
if attn_mask is not None and torch.sum(attn_mask).item():
attn_mask = attn_mask == 1 # Switch to bool
if attn_mask.dim() == 2:
- if next(self.parameters()).dtype == torch.float16:
- attn_score = (
- attn_score.float().masked_fill(attn_mask[None, :, :, None], -65000).type_as(attn_score)
- )
- else:
- attn_score = attn_score.float().masked_fill(attn_mask[None, :, :, None], -1e30).type_as(attn_score)
+ attn_score = (
+ attn_score.float().masked_fill(attn_mask[None, :, :, None], mask_value).type_as(attn_score)
+ )
elif attn_mask.dim() == 3:
- if next(self.parameters()).dtype == torch.float16:
- attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -65000).type_as(attn_score)
- else:
- attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -1e30).type_as(attn_score)
+ attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], mask_value).type_as(attn_score)
# [qlen x klen x bsz x n_head]
attn_prob = nn.functional.softmax(attn_score, dim=1)
@@ -1020,13 +1016,15 @@ def __init__(self, config):
if not self.trainer_compatible:
warnings.warn(
"The output of TransfoXL will be updated in v5 to support a single loss as first argument. In order"
- "to use that updated output, please specify `trainer_compatible=True` as your configuration attribute.",
+ "to use that updated output, please specify `trainer_compatible=True` as your configuration"
+ " attribute.",
DeprecationWarning,
)
- assert (
- self.sample_softmax <= 0
- ), "Sampling from the softmax is not implemented yet. Please look at issue: #3310: https://github.com/huggingface/transformers/issues/3310"
+ assert self.sample_softmax <= 0, (
+ "Sampling from the softmax is not implemented yet. Please look at issue: #3310:"
+ " https://github.com/huggingface/transformers/issues/3310"
+ )
self.crit = ProjectedAdaptiveLogSoftmax(
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
@@ -1196,7 +1194,7 @@ def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[tor
TRANSFO_XL_START_DOCSTRING,
)
class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1261,7 +1259,7 @@ def forward(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[range(batch_size), sequence_lengths]
diff --git a/src/transformers/models/transfo_xl/modeling_transfo_xl_utilities.py b/src/transformers/models/transfo_xl/modeling_transfo_xl_utilities.py
index b25dc2d707d6..e25ba2cd476a 100644
--- a/src/transformers/models/transfo_xl/modeling_transfo_xl_utilities.py
+++ b/src/transformers/models/transfo_xl/modeling_transfo_xl_utilities.py
@@ -102,7 +102,7 @@ def forward(self, hidden, labels=None, keep_order=False):
hidden = hidden.view(-1, hidden.size(-1))
labels = labels.view(-1)
if hidden.size(0) != labels.size(0):
- raise RuntimeError("Input and labels should have the same size " "in the batch dimension.")
+ raise RuntimeError("Input and labels should have the same size in the batch dimension.")
else:
hidden = hidden.view(-1, hidden.size(-1))
diff --git a/src/transformers/models/transfo_xl/tokenization_transfo_xl.py b/src/transformers/models/transfo_xl/tokenization_transfo_xl.py
index 115cd4fdcfca..5b284a219a47 100644
--- a/src/transformers/models/transfo_xl/tokenization_transfo_xl.py
+++ b/src/transformers/models/transfo_xl/tokenization_transfo_xl.py
@@ -27,10 +27,19 @@
import numpy as np
-import sacremoses as sm
-
from ...tokenization_utils import PreTrainedTokenizer
-from ...utils import cached_path, is_torch_available, logging, torch_only_method
+from ...utils import (
+ cached_file,
+ is_sacremoses_available,
+ is_torch_available,
+ logging,
+ requires_backends,
+ torch_only_method,
+)
+
+
+if is_sacremoses_available():
+ import sacremoses as sm
if is_torch_available():
@@ -187,6 +196,7 @@ def __init__(
language=language,
**kwargs,
)
+ requires_backends(self, "sacremoses")
if never_split is None:
never_split = self.all_special_tokens
@@ -671,25 +681,21 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs,
Instantiate a pre-processed corpus.
"""
vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
- if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
- corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
- else:
- corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
# redirect to the cache, if necessary
try:
- resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
+ resolved_corpus_file = cached_file(pretrained_model_name_or_path, CORPUS_NAME, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
- f"Corpus '{pretrained_model_name_or_path}' was not found in corpus list "
- f"({', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys())}. "
- f"We assumed '{pretrained_model_name_or_path}' was a path or url but couldn't find files {corpus_file} "
- "at this path or url."
+ f"Corpus '{pretrained_model_name_or_path}' was not found in corpus list"
+ f" ({', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys())}. We assumed '{pretrained_model_name_or_path}'"
+ f" was a path or url but couldn't find files {CORPUS_NAME} at this path or url."
)
return None
- if resolved_corpus_file == corpus_file:
- logger.info(f"loading corpus file {corpus_file}")
+ if is_local:
+ logger.info(f"loading corpus file {resolved_corpus_file}")
else:
- logger.info(f"loading corpus file {corpus_file} from cache at {resolved_corpus_file}")
+ logger.info(f"loading corpus file {CORPUS_NAME} from cache at {resolved_corpus_file}")
# Instantiate tokenizer.
corpus = cls(*inputs, **kwargs)
diff --git a/src/transformers/models/trocr/__init__.py b/src/transformers/models/trocr/__init__.py
index 5f9f462e1839..8e18eaeb4069 100644
--- a/src/transformers/models/trocr/__init__.py
+++ b/src/transformers/models/trocr/__init__.py
@@ -17,19 +17,27 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_speech_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_sentencepiece_available,
+ is_speech_available,
+ is_torch_available,
+)
_import_structure = {
- "configuration_trocr": [
- "TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP",
- "TrOCRConfig",
- ],
+ "configuration_trocr": ["TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP", "TrOCRConfig"],
"processing_trocr": ["TrOCRProcessor"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_trocr"] = [
"TROCR_PRETRAINED_MODEL_ARCHIVE_LIST",
"TrOCRForCausalLM",
@@ -41,7 +49,12 @@
from .configuration_trocr import TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP, TrOCRConfig
from .processing_trocr import TrOCRProcessor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_trocr import TROCR_PRETRAINED_MODEL_ARCHIVE_LIST, TrOCRForCausalLM, TrOCRPreTrainedModel
else:
diff --git a/src/transformers/models/trocr/configuration_trocr.py b/src/transformers/models/trocr/configuration_trocr.py
index fc878da26d51..a635e6b9b097 100644
--- a/src/transformers/models/trocr/configuration_trocr.py
+++ b/src/transformers/models/trocr/configuration_trocr.py
@@ -21,7 +21,9 @@
logger = logging.get_logger(__name__)
TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "microsoft/trocr-base-handwritten": "https://huggingface.co/microsoft/trocr-base-handwritten/resolve/main/config.json",
+ "microsoft/trocr-base-handwritten": (
+ "https://huggingface.co/microsoft/trocr-base-handwritten/resolve/main/config.json"
+ ),
# See all TrOCR models at https://huggingface.co/models?filter=trocr
}
diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py
index 75e015f98848..a79e5e901d67 100644
--- a/src/transformers/models/trocr/modeling_trocr.py
+++ b/src/transformers/models/trocr/modeling_trocr.py
@@ -50,7 +50,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -72,7 +72,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR
@@ -182,7 +182,8 @@ def __init__(
self.head_dim = embed_dim // num_heads
if not (self.head_dim * num_heads == self.embed_dim):
raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
@@ -254,7 +255,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -270,7 +272,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -291,7 +294,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -520,7 +524,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -667,7 +671,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -683,7 +688,8 @@ def forward(
if use_cache:
logger.warning(
- "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
+ "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache ="
+ " False`..."
)
use_cache = False
@@ -769,7 +775,8 @@ def forward(self, *args, **kwargs):
@add_start_docstrings(
- "The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and [`VisionEncoderDecoder`].",
+ "The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and"
+ " [`VisionEncoderDecoder`].",
TROCR_START_DOCSTRING,
)
class TrOCRForCausalLM(TrOCRPreTrainedModel):
diff --git a/src/transformers/models/trocr/processing_trocr.py b/src/transformers/models/trocr/processing_trocr.py
index 2c7893a0915b..752986243f82 100644
--- a/src/transformers/models/trocr/processing_trocr.py
+++ b/src/transformers/models/trocr/processing_trocr.py
@@ -15,6 +15,7 @@
"""
Processor class for TrOCR.
"""
+import warnings
from contextlib import contextmanager
from ...processing_utils import ProcessorMixin
@@ -40,6 +41,7 @@ class TrOCRProcessor(ProcessorMixin):
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)
self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
def __call__(self, *args, **kwargs):
"""
@@ -48,7 +50,31 @@ def __call__(self, *args, **kwargs):
[`~TrOCRProcessor.as_target_processor`] this method forwards all its arguments to TrOCRTokenizer's
[`~TrOCRTokenizer.__call__`]. Please refer to the doctsring of the above two methods for more information.
"""
- return self.current_processor(*args, **kwargs)
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor(*args, **kwargs)
+
+ images = kwargs.pop("images", None)
+ text = kwargs.pop("text", None)
+ if len(args) > 0:
+ images = args[0]
+ args = args[1:]
+
+ if images is None and text is None:
+ raise ValueError("You need to specify either an `images` or `text` input to process.")
+
+ if images is not None:
+ inputs = self.feature_extractor(images, *args, **kwargs)
+ if text is not None:
+ encodings = self.tokenizer(text, **kwargs)
+
+ if text is None:
+ return inputs
+ elif images is None:
+ return encodings
+ else:
+ inputs["labels"] = encodings["input_ids"]
+ return inputs
def batch_decode(self, *args, **kwargs):
"""
@@ -69,6 +95,13 @@ def as_target_processor(self):
"""
Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning TrOCR.
"""
+ warnings.warn(
+ "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+ "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+ "your images inputs, or in a separate call."
+ )
+ self._in_target_context_manager = True
self.current_processor = self.tokenizer
yield
self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
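For reference, a minimal sketch (not part of this diff) of the single-call API that replaces
`as_target_processor`, assuming the `microsoft/trocr-base-handwritten` checkpoint and a PIL image:

```python
from PIL import Image

from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

image = Image.new("RGB", (384, 384))  # stand-in for a real image of a handwritten line
encoding = processor(images=image, text="industry", return_tensors="pt")

# pixel_values come from the feature extractor, labels from the tokenizer
print(encoding["pixel_values"].shape)
print(encoding["labels"])
```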
diff --git a/src/transformers/models/unispeech/__init__.py b/src/transformers/models/unispeech/__init__.py
index 537b125ec0ef..3713e7d8a11c 100644
--- a/src/transformers/models/unispeech/__init__.py
+++ b/src/transformers/models/unispeech/__init__.py
@@ -17,14 +17,23 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+)
-_import_structure = {
- "configuration_unispeech": ["UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP", "UniSpeechConfig"],
-}
+_import_structure = {"configuration_unispeech": ["UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP", "UniSpeechConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_unispeech"] = [
"UNISPEECH_PRETRAINED_MODEL_ARCHIVE_LIST",
"UniSpeechForCTC",
@@ -37,7 +46,12 @@
if TYPE_CHECKING:
from .configuration_unispeech import UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_unispeech import (
UNISPEECH_PRETRAINED_MODEL_ARCHIVE_LIST,
UniSpeechForCTC,
diff --git a/src/transformers/models/unispeech/configuration_unispeech.py b/src/transformers/models/unispeech/configuration_unispeech.py
index 85b998592094..0c687356de03 100644
--- a/src/transformers/models/unispeech/configuration_unispeech.py
+++ b/src/transformers/models/unispeech/configuration_unispeech.py
@@ -24,7 +24,9 @@
logger = logging.get_logger(__name__)
UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "microsoft/unispeech-large-1500h-cv": "https://huggingface.co/microsoft/unispeech-large-1500h-cv/resolve/main/config.json",
+ "microsoft/unispeech-large-1500h-cv": (
+ "https://huggingface.co/microsoft/unispeech-large-1500h-cv/resolve/main/config.json"
+ ),
# See all UniSpeech models at https://huggingface.co/models?filter=unispeech
}
@@ -78,15 +80,15 @@ class UniSpeechConfig(PretrainedConfig):
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
The dropout probabilitiy for quantized feature encoder states.
- conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
- conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
- of *conv_stride* defines the number of convolutional layers and has to match the the length of *conv_dim*.
- conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
- length of *conv_kernel* defines the number of convolutional layers and has to match the the length of
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
*conv_dim*.
conv_bias (`bool`, *optional*, defaults to `False`):
Whether the 1D convolutional layers have a bias.
@@ -261,10 +263,10 @@ def __init__(
or (len(self.conv_dim) != self.num_feat_extract_layers)
):
raise ValueError(
- "Configuration for convolutional layers is incorrect. "
- "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
- f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
- f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
diff --git a/src/transformers/models/unispeech/convert_unispeech_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/unispeech/convert_unispeech_original_pytorch_checkpoint_to_pytorch.py
index 83f051627cc3..bf729309515e 100644
--- a/src/transformers/models/unispeech/convert_unispeech_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/unispeech/convert_unispeech_original_pytorch_checkpoint_to_pytorch.py
@@ -84,9 +84,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type, is_finetuned
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -154,28 +155,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py
index 61359bf032f0..dc194318e999 100755
--- a/src/transformers/models/unispeech/modeling_unispeech.py
+++ b/src/transformers/models/unispeech/modeling_unispeech.py
@@ -27,7 +27,7 @@
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import (
@@ -71,35 +71,6 @@
]
-@dataclass
-class UniSpeechBaseModelOutput(ModelOutput):
- """
- Output type of [`UniSpeechBaseModelOutput`], with potential hidden states and attentions.
-
- Args:
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Sequence of hidden-states at the output of the last layer of the model.
- extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
- Sequence of extracted feature vectors of the last convolutional layer of the model.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- last_hidden_state: torch.FloatTensor = None
- extract_features: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
@dataclass
class UniSpeechForPreTrainingOutput(ModelOutput):
"""
@@ -239,7 +210,7 @@ def compute_num_masked_span(input_length):
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
- # add offset to the starting indexes so that that indexes now create a span
+ # add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
@@ -554,7 +525,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -570,7 +542,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -591,7 +564,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -723,10 +697,12 @@ def forward(
if attention_mask is not None:
# make sure padded tokens output 0
- hidden_states[~attention_mask] = 0.0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
- attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
@@ -811,10 +787,12 @@ def forward(
if attention_mask is not None:
# make sure padded tokens are not attended to
- hidden_states[~attention_mask] = 0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
- attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
@@ -888,7 +866,8 @@ def __init__(self, config):
if config.codevector_dim % self.num_groups != 0:
raise ValueError(
- f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups` {self.num_groups} for concatenation"
+ f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups`"
+ f" {self.num_groups} for concatenation"
)
# storage for codebook variables (codewords)
@@ -1154,7 +1133,7 @@ def _mask_hidden_states(
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=UniSpeechBaseModelOutput,
+ output_type=Wav2Vec2BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -1167,7 +1146,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, UniSpeechBaseModelOutput]:
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1199,7 +1178,7 @@ def forward(
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
- return UniSpeechBaseModelOutput(
+ return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
diff --git a/src/transformers/models/unispeech_sat/__init__.py b/src/transformers/models/unispeech_sat/__init__.py
index 75a7397ff7e4..d4a5e179539a 100644
--- a/src/transformers/models/unispeech_sat/__init__.py
+++ b/src/transformers/models/unispeech_sat/__init__.py
@@ -17,14 +17,25 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+)
_import_structure = {
"configuration_unispeech_sat": ["UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP", "UniSpeechSatConfig"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_unispeech_sat"] = [
"UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST",
"UniSpeechSatForAudioFrameClassification",
@@ -39,7 +50,12 @@
if TYPE_CHECKING:
from .configuration_unispeech_sat import UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechSatConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_unispeech_sat import (
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST,
UniSpeechSatForAudioFrameClassification,
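The `try`/`except OptionalDependencyNotAvailable` blocks introduced here (and in the `van` and `videomae` `__init__.py` files further down) all follow the same shape. A stripped-down sketch of the idea with generic names, not the actual transformers helpers:

```python
import importlib.util


class OptionalDependencyNotAvailable(BaseException):
    """Raised when a backend required by an optional submodule is missing."""


def is_torch_available() -> bool:
    return importlib.util.find_spec("torch") is not None


_import_structure = {"configuration": ["MyConfig"]}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass  # torch is missing: the modeling symbols are simply not registered
else:
    _import_structure["modeling"] = ["MyModel"]

print(_import_structure)
```

The sentinel-exception form keeps every backend check structurally identical, so the lazy-import init files read the same way regardless of which optional dependency is being guarded.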
diff --git a/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py b/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py
index b88d9cf91fc9..3205bbc2cca8 100644
--- a/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py
@@ -24,7 +24,9 @@
logger = logging.get_logger(__name__)
UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "microsoft/unispeech-sat-base-100h-libri-ft": "https://huggingface.co/microsoft/unispeech-sat-base-100h-libri-ft/resolve/main/config.json",
+ "microsoft/unispeech-sat-base-100h-libri-ft": (
+ "https://huggingface.co/microsoft/unispeech-sat-base-100h-libri-ft/resolve/main/config.json"
+ ),
# See all UniSpeechSat models at https://huggingface.co/models?filter=unispeech_sat
}
@@ -79,15 +81,15 @@ class UniSpeechSatConfig(PretrainedConfig):
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for quantized feature encoder states.
- conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
- conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
- of *conv_stride* defines the number of convolutional layers and has to match the the length of *conv_dim*.
- conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
- length of *conv_kernel* defines the number of convolutional layers and has to match the the length of
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
*conv_dim*.
conv_bias (`bool`, *optional*, defaults to `False`):
Whether the 1D convolutional layers have a bias.
@@ -157,13 +159,13 @@ class UniSpeechSatConfig(PretrainedConfig):
instance of [`UniSpeechSatForSequenceClassification`].
classifier_proj_size (`int`, *optional*, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
- tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+ tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
- tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+ tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
*XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
- tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+ tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
*XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
xvector_output_dim (`int`, *optional*, defaults to 512):
@@ -273,10 +275,10 @@ def __init__(
or (len(self.conv_dim) != self.num_feat_extract_layers)
):
raise ValueError(
- "Configuration for convolutional layers is incorrect. "
- "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
- f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
- f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
diff --git a/src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py
index 78a541d7ed49..93750b64cc3a 100644
--- a/src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py
@@ -72,7 +72,8 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
if hf_shape != value.shape:
raise ValueError(
- f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
)
if weight_type == "weight":
@@ -146,14 +147,16 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
@@ -161,14 +164,16 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
index 1812cd65237e..926464d3bf8e 100755
--- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -27,7 +27,14 @@
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
+from ...modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import (
@@ -77,35 +84,6 @@
]
-@dataclass
-class UniSpeechSatBaseModelOutput(ModelOutput):
- """
- Output type of [`UniSpeechSatBaseModelOutput`], with potential hidden states and attentions.
-
- Args:
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Sequence of hidden-states at the output of the last layer of the model.
- extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
- Sequence of extracted feature vectors of the last convolutional layer of the model.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- last_hidden_state: torch.FloatTensor = None
- extract_features: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
@dataclass
class UniSpeechSatForPreTrainingOutput(ModelOutput):
"""
@@ -143,38 +121,6 @@ class UniSpeechSatForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
-@dataclass
-class XVectorOutput(ModelOutput):
- """
- Output type of [`Wav2Vec2ForXVector`].
-
- Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Classification loss.
- logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
- Classification hidden states before AMSoftmax.
- embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
- Utterance embeddings used for vector similarity-based retrieval.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- loss: Optional[torch.FloatTensor] = None
- logits: torch.FloatTensor = None
- embeddings: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
@@ -278,7 +224,7 @@ def compute_num_masked_span(input_length):
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
- # add offset to the starting indexes so that that indexes now create a span
+ # add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
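The offset broadcast shown in this hunk turns each sampled start index into a contiguous masked span. A toy version with assumed sizes:

```python
import numpy as np

batch_size, max_num_masked_span, mask_length = 1, 2, 3
# two sampled span starts for the single batch item
spec_aug_mask_idxs = np.repeat(np.array([[0, 5]]), mask_length, axis=-1).reshape(
    batch_size, max_num_masked_span * mask_length
)
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
    batch_size, max_num_masked_span * mask_length
)
print(spec_aug_mask_idxs + offsets)  # [[0 1 2 5 6 7]] -> spans [0, 3) and [5, 8)
```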
@@ -593,7 +539,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -609,7 +556,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -630,7 +578,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -762,10 +711,12 @@ def forward(
if attention_mask is not None:
# make sure padded tokens output 0
- hidden_states[~attention_mask] = 0.0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
- attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
@@ -850,10 +801,12 @@ def forward(
if attention_mask is not None:
# make sure padded tokens are not attended to
- hidden_states[~attention_mask] = 0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
- attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
@@ -927,7 +880,8 @@ def __init__(self, config):
if config.codevector_dim % self.num_groups != 0:
raise ValueError(
- f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups` {self.num_groups} for concatenation"
+ f"`config.codevector_dim` {config.codevector_dim} must be divisible by `config.num_codevector_groups`"
+ f" {self.num_groups} for concatenation"
)
# storage for codebook variables (codewords)
@@ -1194,7 +1148,7 @@ def _mask_hidden_states(
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=UniSpeechSatBaseModelOutput,
+ output_type=Wav2Vec2BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -1207,7 +1161,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, UniSpeechSatBaseModelOutput]:
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1239,7 +1193,7 @@ def forward(
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
- return UniSpeechSatBaseModelOutput(
+ return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@@ -1651,13 +1605,15 @@ def __init__(self, config):
if hasattr(config, "add_adapter") and config.add_adapter:
raise ValueError(
- "Audio frame classification does not support the use of UniSpeechSat adapters (config.add_adapter=True)"
+ "Audio frame classification does not support the use of UniSpeechSat adapters"
+ " (config.add_adapter=True)"
)
self.unispeech_sat = UniSpeechSatModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ self.num_labels = config.num_labels
self.init_weights()
@@ -1701,6 +1657,7 @@ def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -1733,12 +1690,17 @@ def forward(
logits = self.classifier(hidden_states)
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return output
return TokenClassifierOutput(
- loss=None,
+ loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
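The new loss path in `UniSpeechSatForAudioFrameClassification` flattens the per-frame logits and converts one-hot style frame labels to class indices before `CrossEntropyLoss`. A small sketch with made-up shapes (`dim=1` is used here in place of the NumPy-style `axis=1` alias in the diff):

```python
import torch
from torch.nn import CrossEntropyLoss

batch, frames, num_labels = 2, 5, 3
logits = torch.randn(batch, frames, num_labels)
# labels arrive as one vector of size num_labels per frame
labels = torch.nn.functional.one_hot(torch.randint(0, num_labels, (batch, frames)), num_labels).float()

loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, num_labels), torch.argmax(labels.view(-1, num_labels), dim=1))
print(loss.item())
```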
diff --git a/src/transformers/models/van/__init__.py b/src/transformers/models/van/__init__.py
index 73e2752b1f2e..44c88f0448c3 100644
--- a/src/transformers/models/van/__init__.py
+++ b/src/transformers/models/van/__init__.py
@@ -18,15 +18,18 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
-_import_structure = {
- "configuration_van": ["VAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "VanConfig"],
-}
+_import_structure = {"configuration_van": ["VAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "VanConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_van"] = [
"VAN_PRETRAINED_MODEL_ARCHIVE_LIST",
"VanForImageClassification",
@@ -37,7 +40,12 @@
if TYPE_CHECKING:
from .configuration_van import VAN_PRETRAINED_CONFIG_ARCHIVE_MAP, VanConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_van import (
VAN_PRETRAINED_MODEL_ARCHIVE_LIST,
VanForImageClassification,
diff --git a/src/transformers/models/van/configuration_van.py b/src/transformers/models/van/configuration_van.py
index 6d4becdf552b..47d5a9b6c11a 100644
--- a/src/transformers/models/van/configuration_van.py
+++ b/src/transformers/models/van/configuration_van.py
@@ -21,7 +21,9 @@
logger = logging.get_logger(__name__)
VAN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "Visual-Attention-Network/van-base": "https://huggingface.co/Visual-Attention-Network/van-base/blob/main/config.json",
+ "Visual-Attention-Network/van-base": (
+ "https://huggingface.co/Visual-Attention-Network/van-base/blob/main/config.json"
+ ),
}
diff --git a/src/transformers/models/van/convert_van_to_pytorch.py b/src/transformers/models/van/convert_van_to_pytorch.py
index cb79c82c5c9e..e2c0c95e6450 100644
--- a/src/transformers/models/van/convert_van_to_pytorch.py
+++ b/src/transformers/models/van/convert_van_to_pytorch.py
@@ -85,7 +85,8 @@ def __call__(self, x: Tensor):
if len(dest_traced) != len(src_traced):
raise Exception(
- f"Numbers of operations are different. Source module has {len(src_traced)} operations while destination module has {len(dest_traced)}."
+ f"Numbers of operations are different. Source module has {len(src_traced)} operations while"
+ f" destination module has {len(dest_traced)}."
)
for dest_m, src_m in zip(dest_traced, src_traced):
@@ -208,10 +209,18 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
}
names_to_original_checkpoints = {
- "van-tiny": "https://huggingface.co/Visual-Attention-Network/VAN-Tiny-original/resolve/main/van_tiny_754.pth.tar",
- "van-small": "https://huggingface.co/Visual-Attention-Network/VAN-Small-original/resolve/main/van_small_811.pth.tar",
- "van-base": "https://huggingface.co/Visual-Attention-Network/VAN-Base-original/resolve/main/van_base_828.pth.tar",
- "van-large": "https://huggingface.co/Visual-Attention-Network/VAN-Large-original/resolve/main/van_large_839.pth.tar",
+ "van-tiny": (
+ "https://huggingface.co/Visual-Attention-Network/VAN-Tiny-original/resolve/main/van_tiny_754.pth.tar"
+ ),
+ "van-small": (
+ "https://huggingface.co/Visual-Attention-Network/VAN-Small-original/resolve/main/van_small_811.pth.tar"
+ ),
+ "van-base": (
+ "https://huggingface.co/Visual-Attention-Network/VAN-Base-original/resolve/main/van_base_828.pth.tar"
+ ),
+ "van-large": (
+ "https://huggingface.co/Visual-Attention-Network/VAN-Large-original/resolve/main/van_large_839.pth.tar"
+ ),
}
if model_name:
@@ -242,7 +251,10 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
"--model-name",
default=None,
type=str,
- help="The name of the model you wish to convert, it must be one of the supported resnet* architecture, currently: van-tiny/small/base/large. If `None`, all of them will the converted.",
+ help=(
+ "The name of the model you wish to convert, it must be one of the supported VAN architectures,"
+ " currently: van-tiny/small/base/large. If `None`, all of them will be converted."
+ ),
)
parser.add_argument(
"--pytorch_dump_folder_path",
@@ -255,7 +267,10 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
"--van_dir",
required=True,
type=Path,
- help="A path to VAN's original implementation directory. You can download from here: https://github.com/Visual-Attention-Network/VAN-Classification",
+ help=(
+ "A path to VAN's original implementation directory. You can download from here:"
+ " https://github.com/Visual-Attention-Network/VAN-Classification"
+ ),
)
parser.add_argument(
"--push_to_hub",
diff --git a/src/transformers/models/van/modeling_van.py b/src/transformers/models/van/modeling_van.py
index 7a7030c2f569..5e212d5f485d 100644
--- a/src/transformers/models/van/modeling_van.py
+++ b/src/transformers/models/van/modeling_van.py
@@ -54,23 +54,24 @@
]
-# Stochastic depth implementation
-# Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
-def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+# Copied from transformers.models.convnext.modeling_convnext.drop_path
+def drop_path(input, drop_prob: float = 0.0, training: bool = False):
"""
- Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
- DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
- Connect' is a different form of dropout in a separate paper... See discussion:
- https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
- argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
"""
if drop_prob == 0.0 or not training:
- return x
+ return input
keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize
- output = x.div(keep_prob) * random_tensor
+ output = input.div(keep_prob) * random_tensor
return output
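Since `drop_path` now carries the `# Copied from` ConvNeXt version, a quick usage reminder of what stochastic depth does per sample (the snippet restates the function above so it runs standalone):

```python
import torch

def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # one draw per sample
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize: 1 keeps the residual branch, 0 drops it
    return input.div(keep_prob) * random_tensor  # rescale so the expectation is unchanged

x = torch.ones(4, 2, 2)
print(drop_path(x, drop_prob=0.5, training=True)[:, 0, 0])  # each sample is either 0. or 2.
```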
@@ -78,15 +79,18 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
class VanDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
- def __init__(self, drop_prob=None):
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
-class VanOverlappingPatchEmbedder(nn.Sequential):
+class VanOverlappingPatchEmbedder(nn.Module):
"""
Downsamples the input using a patchify operation with a `stride` of 4 by default making adjacent windows overlap by
half of the area. From [PVTv2: Improved Baselines with Pyramid Vision
@@ -100,8 +104,13 @@ def __init__(self, in_channels: int, hidden_size: int, patch_size: int = 7, stri
)
self.normalization = nn.BatchNorm2d(hidden_size)
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.convolution(input)
+ hidden_state = self.normalization(hidden_state)
+ return hidden_state
+
-class VanMlpLayer(nn.Sequential):
+class VanMlpLayer(nn.Module):
"""
MLP with depth-wise convolution, from [PVTv2: Improved Baselines with Pyramid Vision
Transformer](https://arxiv.org/abs/2106.13797).
@@ -123,8 +132,17 @@ def __init__(
self.out_dense = nn.Conv2d(hidden_size, out_channels, kernel_size=1)
self.dropout2 = nn.Dropout(dropout_rate)
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.in_dense(hidden_state)
+ hidden_state = self.depth_wise(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ hidden_state = self.dropout1(hidden_state)
+ hidden_state = self.out_dense(hidden_state)
+ hidden_state = self.dropout2(hidden_state)
+ return hidden_state
+
-class VanLargeKernelAttention(nn.Sequential):
+class VanLargeKernelAttention(nn.Module):
"""
Basic Large Kernel Attention (LKA).
"""
@@ -137,6 +155,12 @@ def __init__(self, hidden_size: int):
)
self.point_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.depth_wise(hidden_state)
+ hidden_state = self.depth_wise_dilated(hidden_state)
+ hidden_state = self.point_wise(hidden_state)
+ return hidden_state
+
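The `nn.Sequential` -> `nn.Module` conversions above keep the submodule names (so existing checkpoints still load) while spelling out the data flow in an explicit `forward`. A self-contained sketch of the large-kernel-attention variant, with the usual LKA kernel/dilation settings assumed rather than read from the config:

```python
import torch
from torch import nn


class LargeKernelAttentionSketch(nn.Module):
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        # depth-wise conv, then depth-wise dilated conv, then point-wise conv
        self.depth_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=5, padding=2, groups=hidden_size)
        self.depth_wise_dilated = nn.Conv2d(
            hidden_size, hidden_size, kernel_size=7, dilation=3, padding=9, groups=hidden_size
        )
        self.point_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.depth_wise(hidden_state)
        hidden_state = self.depth_wise_dilated(hidden_state)
        hidden_state = self.point_wise(hidden_state)
        return hidden_state


print(LargeKernelAttentionSketch()(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])
```

Functionally this is equivalent to the previous `nn.Sequential`, just with the call order written out.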
class VanLargeKernelAttentionLayer(nn.Module):
"""
@@ -395,7 +419,8 @@ def _set_gradient_checkpointing(self, module, value=False):
@add_start_docstrings(
- "The bare VAN model outputting raw features without any specific head on top. Note, VAN does not have an embedding layer.",
+ "The bare VAN model outputting raw features without any specific head on top. Note, VAN does not have an embedding"
+ " layer.",
VAN_START_DOCSTRING,
)
class VanModel(VanPreTrainedModel):
diff --git a/src/transformers/models/videomae/__init__.py b/src/transformers/models/videomae/__init__.py
new file mode 100644
index 000000000000..fb239c6063ba
--- /dev/null
+++ b/src/transformers/models/videomae/__init__.py
@@ -0,0 +1,77 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+
+
+_import_structure = {
+ "configuration_videomae": ["VIDEOMAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "VideoMAEConfig"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_videomae"] = [
+ "VIDEOMAE_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "VideoMAEForPreTraining",
+ "VideoMAEModel",
+ "VideoMAEPreTrainedModel",
+ "VideoMAEForVideoClassification",
+ ]
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_videomae"] = ["VideoMAEFeatureExtractor"]
+
+if TYPE_CHECKING:
+ from .configuration_videomae import VIDEOMAE_PRETRAINED_CONFIG_ARCHIVE_MAP, VideoMAEConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_videomae import (
+ VIDEOMAE_PRETRAINED_MODEL_ARCHIVE_LIST,
+ VideoMAEForPreTraining,
+ VideoMAEForVideoClassification,
+ VideoMAEModel,
+ VideoMAEPreTrainedModel,
+ )
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_videomae import VideoMAEFeatureExtractor
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/videomae/configuration_videomae.py b/src/transformers/models/videomae/configuration_videomae.py
new file mode 100644
index 000000000000..932c4c1d98ca
--- /dev/null
+++ b/src/transformers/models/videomae/configuration_videomae.py
@@ -0,0 +1,148 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" VideoMAE model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VIDEOMAE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "MCG-NJU/videomae-base": "https://huggingface.co/MCG-NJU/videomae-base/resolve/main/config.json",
+}
+
+
+class VideoMAEConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`VideoMAEModel`]. It is used to instantiate a
+ VideoMAE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the VideoMAE
+ [MCG-NJU/videomae-base](https://huggingface.co/MCG-NJU/videomae-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ num_frames (`int`, *optional*, defaults to 16):
+ The number of frames in each video.
+ tubelet_size (`int`, *optional*, defaults to 2):
+ The number of tubelets.
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ use_mean_pooling (`bool`, *optional*, defaults to `True`):
+ Whether to mean pool the final hidden states instead of using the final hidden state of the [CLS] token.
+ decoder_num_attention_heads (`int`, *optional*, defaults to 6):
+ Number of attention heads for each attention layer in the decoder.
+ decoder_hidden_size (`int`, *optional*, defaults to 384):
+ Dimensionality of the decoder.
+ decoder_num_hidden_layers (`int`, *optional*, defaults to 4):
+ Number of hidden layers in the decoder.
+ decoder_intermediate_size (`int`, *optional*, defaults to 1536):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the decoder.
+ norm_pix_loss (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the target patch pixels.
+
+ Example:
+
+ ```python
+ >>> from transformers import VideoMAEConfig, VideoMAEModel
+
+ >>> # Initializing a VideoMAE videomae-base style configuration
+ >>> configuration = VideoMAEConfig()
+
+ >>> # Randomly initializing a model from the configuration
+ >>> model = VideoMAEModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "videomae"
+
+ def __init__(
+ self,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ num_frames=16,
+ tubelet_size=2,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ qkv_bias=True,
+ use_mean_pooling=True,
+ decoder_num_attention_heads=6,
+ decoder_hidden_size=384,
+ decoder_num_hidden_layers=4,
+ decoder_intermediate_size=1536,
+ norm_pix_loss=True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_frames = num_frames
+ self.tubelet_size = tubelet_size
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.qkv_bias = qkv_bias
+ self.use_mean_pooling = use_mean_pooling
+
+ self.decoder_num_attention_heads = decoder_num_attention_heads
+ self.decoder_hidden_size = decoder_hidden_size
+ self.decoder_num_hidden_layers = decoder_num_hidden_layers
+ self.decoder_intermediate_size = decoder_intermediate_size
+ self.norm_pix_loss = norm_pix_loss
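For orientation, the token geometry implied by the defaults above; the arithmetic is a reading of the documented defaults, not code from the model file, and 1536 lines up with the last dimension of the pretraining logits checked in the conversion script below:

```python
# defaults documented above: image_size=224, patch_size=16, num_frames=16, tubelet_size=2, num_channels=3
patches_per_frame = (224 // 16) ** 2      # 14 x 14 = 196 spatial patches
temporal_patches = 16 // 2                # 8 tubelets along time
sequence_length = patches_per_frame * temporal_patches
pixels_per_tubelet = 2 * 16 * 16 * 3      # tubelet_size * patch_size**2 * num_channels
print(sequence_length, pixels_per_tubelet)  # 1568 1536
```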
diff --git a/src/transformers/models/videomae/convert_videomae_to_pytorch.py b/src/transformers/models/videomae/convert_videomae_to_pytorch.py
new file mode 100644
index 000000000000..60e5ae8f5f41
--- /dev/null
+++ b/src/transformers/models/videomae/convert_videomae_to_pytorch.py
@@ -0,0 +1,286 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert VideoMAE checkpoints from the original repository: https://github.com/MCG-NJU/VideoMAE"""
+
+import argparse
+import json
+
+import numpy as np
+import torch
+
+import gdown
+from huggingface_hub import hf_hub_download
+from transformers import (
+ VideoMAEConfig,
+ VideoMAEFeatureExtractor,
+ VideoMAEForPreTraining,
+ VideoMAEForVideoClassification,
+)
+
+
+def get_videomae_config(model_name):
+ config = VideoMAEConfig()
+
+ if "large" in model_name:
+ config.hidden_size = 1024
+ config.intermediate_size = 4096
+ config.num_hidden_layers = 24
+ config.num_attention_heads = 16
+ config.decoder_num_hidden_layers = 12
+ config.decoder_num_attention_heads = 8
+ config.decoder_hidden_size = 512
+ config.decoder_intermediate_size = 2048
+
+ if "finetuned" not in model_name:
+ config.use_mean_pooling = False
+
+ if "finetuned" in model_name:
+ repo_id = "datasets/huggingface/label-files"
+ if "kinetics" in model_name:
+ config.num_labels = 400
+ filename = "kinetics400-id2label.json"
+ elif "ssv2" in model_name:
+ config.num_labels = 174
+ filename = "something-something-v2-id2label.json"
+ else:
+ raise ValueError("Model name should either contain 'kinetics' or 'ssv2' in case it's fine-tuned.")
+ id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+ id2label = {int(k): v for k, v in id2label.items()}
+ config.id2label = id2label
+ config.label2id = {v: k for k, v in id2label.items()}
+
+ return config
+
+
+def rename_key(name):
+ if "encoder." in name:
+ name = name.replace("encoder.", "")
+ if "cls_token" in name:
+ name = name.replace("cls_token", "videomae.embeddings.cls_token")
+ if "decoder_pos_embed" in name:
+ name = name.replace("decoder_pos_embed", "decoder.decoder_pos_embed")
+ if "pos_embed" in name and "decoder" not in name:
+ name = name.replace("pos_embed", "videomae.embeddings.position_embeddings")
+ if "patch_embed.proj" in name:
+ name = name.replace("patch_embed.proj", "videomae.embeddings.patch_embeddings.projection")
+ if "patch_embed.norm" in name:
+ name = name.replace("patch_embed.norm", "videomae.embeddings.norm")
+ if "decoder.blocks" in name:
+ name = name.replace("decoder.blocks", "decoder.decoder_layers")
+ if "blocks" in name:
+ name = name.replace("blocks", "videomae.encoder.layer")
+ if "attn.proj" in name:
+ name = name.replace("attn.proj", "attention.output.dense")
+ if "attn" in name and "bias" not in name:
+ name = name.replace("attn", "attention.self")
+ if "attn" in name:
+ name = name.replace("attn", "attention.attention")
+ if "norm1" in name:
+ name = name.replace("norm1", "layernorm_before")
+ if "norm2" in name:
+ name = name.replace("norm2", "layernorm_after")
+ if "mlp.fc1" in name:
+ name = name.replace("mlp.fc1", "intermediate.dense")
+ if "mlp.fc2" in name:
+ name = name.replace("mlp.fc2", "output.dense")
+ if "decoder_embed" in name:
+ name = name.replace("decoder_embed", "decoder.decoder_embed")
+ if "decoder_norm" in name:
+ name = name.replace("decoder_norm", "decoder.decoder_norm")
+ if "decoder_pred" in name:
+ name = name.replace("decoder_pred", "decoder.decoder_pred")
+ if "norm.weight" in name and "decoder" not in name and "fc" not in name:
+ name = name.replace("norm.weight", "videomae.layernorm.weight")
+ if "norm.bias" in name and "decoder" not in name and "fc" not in name:
+ name = name.replace("norm.bias", "videomae.layernorm.bias")
+ if "head" in name and "decoder" not in name:
+ name = name.replace("head", "classifier")
+
+ return name
+
+
+def convert_state_dict(orig_state_dict, config):
+ for key in orig_state_dict.copy().keys():
+ val = orig_state_dict.pop(key)
+
+ if key.startswith("encoder."):
+ key = key.replace("encoder.", "")
+
+ if "qkv" in key:
+ key_split = key.split(".")
+ if key.startswith("decoder.blocks"):
+ dim = config.decoder_hidden_size
+ layer_num = int(key_split[2])
+ prefix = "decoder.decoder_layers."
+ if "weight" in key:
+ orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.weight"] = val[:dim, :]
+ orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.weight"] = val[dim : dim * 2, :]
+ orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
+ else:
+ dim = config.hidden_size
+ layer_num = int(key_split[1])
+ prefix = "videomae.encoder.layer."
+ if "weight" in key:
+ orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.weight"] = val[:dim, :]
+ orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.weight"] = val[dim : dim * 2, :]
+ orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
+ else:
+ orig_state_dict[rename_key(key)] = val
+
+ return orig_state_dict
+
+
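`convert_state_dict` above splits each fused `qkv` projection into separate query/key/value weights by slicing rows. A toy check of that slicing with a made-up dimension:

```python
import torch

dim = 4
fused_qkv_weight = torch.randn(3 * dim, dim)  # rows are [query; key; value]

query_weight = fused_qkv_weight[:dim, :]
key_weight = fused_qkv_weight[dim : dim * 2, :]
value_weight = fused_qkv_weight[-dim:, :]

assert torch.equal(torch.cat([query_weight, key_weight, value_weight], dim=0), fused_qkv_weight)
```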
+# We will verify our results on a video of eating spaghetti
+# Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227]
+def prepare_video():
+ file = hf_hub_download(repo_id="datasets/hf-internal-testing/spaghetti-video", filename="eating_spaghetti.npy")
+ video = np.load(file)
+ return list(video)
+
+
+def convert_videomae_checkpoint(checkpoint_url, pytorch_dump_folder_path, model_name, push_to_hub):
+ config = get_videomae_config(model_name)
+
+ if "finetuned" in model_name:
+ model = VideoMAEForVideoClassification(config)
+ else:
+ model = VideoMAEForPreTraining(config)
+
+ # download original checkpoint, hosted on Google Drive
+ output = "pytorch_model.bin"
+ gdown.cached_download(checkpoint_url, output, quiet=False)
+ files = torch.load(output, map_location="cpu")
+ if "model" in files:
+ state_dict = files["model"]
+ else:
+ state_dict = files["module"]
+ new_state_dict = convert_state_dict(state_dict, config)
+
+ model.load_state_dict(new_state_dict)
+ model.eval()
+
+ # verify model on basic input
+ feature_extractor = VideoMAEFeatureExtractor(image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5])
+ video = prepare_video()
+ inputs = feature_extractor(video, return_tensors="pt")
+
+ if "finetuned" not in model_name:
+ local_path = hf_hub_download(repo_id="hf-internal-testing/bool-masked-pos", filename="bool_masked_pos.pt")
+ inputs["bool_masked_pos"] = torch.load(local_path)
+
+ outputs = model(**inputs)
+ logits = outputs.logits
+
+ model_names = [
+ # Kinetics-400 checkpoints (short = pretrained only for 800 epochs instead of 1600)
+ "videomae-base-short",
+ "videomae-base-short-finetuned-kinetics",
+ "videomae-base",
+ "videomae-base-finetuned-kinetics",
+ "videomae-large",
+ "videomae-large-finetuned-kinetics",
+ # Something-Something-v2 checkpoints (short = pretrained only for 800 epochs instead of 2400)
+ "videomae-base-short-ssv2",
+ "videomae-base-short-finetuned-ssv2",
+ "videomae-base-ssv2",
+ "videomae-base-finetuned-ssv2",
+ ]
+
+ # NOTE: logits were tested with image_mean and image_std equal to [0.5, 0.5, 0.5] and [0.5, 0.5, 0.5]
+ if model_name == "videomae-base":
+ expected_shape = torch.Size([1, 1408, 1536])
+ expected_slice = torch.tensor([[0.7739, 0.7968, 0.7089], [0.6701, 0.7487, 0.6209], [0.4287, 0.5158, 0.4773]])
+ elif model_name == "videomae-base-short":
+ expected_shape = torch.Size([1, 1408, 1536])
+ expected_slice = torch.tensor([[0.7994, 0.9612, 0.8508], [0.7401, 0.8958, 0.8302], [0.5862, 0.7468, 0.7325]])
+ # we verified the loss both for normalized and unnormalized targets for this one
+ expected_loss = torch.tensor([0.5142]) if config.norm_pix_loss else torch.tensor([0.6469])
+ elif model_name == "videomae-large":
+ expected_shape = torch.Size([1, 1408, 1536])
+ expected_slice = torch.tensor([[0.7149, 0.7997, 0.6966], [0.6768, 0.7869, 0.6948], [0.5139, 0.6221, 0.5605]])
+ elif model_name == "videomae-large-finetuned-kinetics":
+ expected_shape = torch.Size([1, 400])
+ expected_slice = torch.tensor([0.0771, 0.0011, -0.3625])
+ elif model_name == "videomae-base-short-finetuned-kinetics":
+ expected_shape = torch.Size([1, 400])
+ expected_slice = torch.tensor([0.6588, 0.0990, -0.2493])
+ elif model_name == "videomae-base-finetuned-kinetics":
+ expected_shape = torch.Size([1, 400])
+ expected_slice = torch.tensor([0.3669, -0.0688, -0.2421])
+ elif model_name == "videomae-base-short-ssv2":
+ expected_shape = torch.Size([1, 1408, 1536])
+ expected_slice = torch.tensor([[0.4712, 0.5296, 0.5786], [0.2278, 0.2729, 0.4026], [0.0352, 0.0730, 0.2506]])
+ elif model_name == "videomae-base-short-finetuned-ssv2":
+ expected_shape = torch.Size([1, 174])
+ expected_slice = torch.tensor([-0.0537, -0.1539, -0.3266])
+ elif model_name == "videomae-base-ssv2":
+ expected_shape = torch.Size([1, 1408, 1536])
+ expected_slice = torch.tensor([[0.8131, 0.8727, 0.8546], [0.7366, 0.9377, 0.8870], [0.5935, 0.8874, 0.8564]])
+ elif model_name == "videomae-base-finetuned-ssv2":
+ expected_shape = torch.Size([1, 174])
+ expected_slice = torch.tensor([0.1961, -0.8337, -0.6389])
+ else:
+ raise ValueError(f"Model name not supported. Should be one of {model_names}")
+
+ # verify logits
+ assert logits.shape == expected_shape
+ if "finetuned" in model_name:
+ assert torch.allclose(logits[0, :3], expected_slice, atol=1e-4)
+ else:
+ print("Logits:", logits[0, :3, :3])
+ assert torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4)
+ print("Logits ok!")
+
+ # verify loss, if applicable
+ if model_name == "videomae-base-short":
+ loss = outputs.loss
+ assert torch.allclose(loss, expected_loss, atol=1e-4)
+ print("Loss ok!")
+
+ if pytorch_dump_folder_path is not None:
+ print(f"Saving model and feature extractor to {pytorch_dump_folder_path}")
+ feature_extractor.save_pretrained(pytorch_dump_folder_path)
+ model.save_pretrained(pytorch_dump_folder_path)
+
+ if push_to_hub:
+ print("Pushing to the hub...")
+ model.push_to_hub(model_name, organization="nielsr")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--checkpoint_url",
+ default="https://drive.google.com/u/1/uc?id=1tEhLyskjb755TJ65ptsrafUG2llSwQE1&export=download&confirm=t&uuid=aa3276eb-fb7e-482a-adec-dc7171df14c4",
+ type=str,
+ help=(
+ "URL of the original PyTorch checkpoint (on Google Drive) you'd like to convert. Should be a direct"
+ " download link."
+ ),
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path",
+ default="/Users/nielsrogge/Documents/VideoMAE/Test",
+ type=str,
+ help="Path to the output PyTorch model directory.",
+ )
+ parser.add_argument("--model_name", default="videomae-base", type=str, help="Name of the model.")
+ parser.add_argument(
+ "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
+ )
+
+ args = parser.parse_args()
+ convert_videomae_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path, args.model_name, args.push_to_hub)
diff --git a/src/transformers/models/videomae/feature_extraction_videomae.py b/src/transformers/models/videomae/feature_extraction_videomae.py
new file mode 100644
index 000000000000..132dabda8c68
--- /dev/null
+++ b/src/transformers/models/videomae/feature_extraction_videomae.py
@@ -0,0 +1,169 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for VideoMAE."""
+
+from typing import Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
+from ...utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class VideoMAEFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
+ r"""
+ Constructs a VideoMAE feature extractor. This feature extractor can be used to prepare videos for the model.
+
+ This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
+ should refer to this superclass for more information regarding those methods.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the shorter edge of the input to a certain `size`.
+ size (`int`, *optional*, defaults to 224):
+ Resize the shorter edge of the input to the given size. Only has an effect if `do_resize` is set to `True`.
+ resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
+ An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+ `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+ if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the input to a certain `size`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input with mean and standard deviation.
+ image_mean (`List[float]`, defaults to `[0.485, 0.456, 0.406]`):
+ The sequence of means for each channel, to be used when normalizing images.
+ image_std (`List[float]`, defaults to `[0.229, 0.224, 0.225]`):
+ The sequence of standard deviations for each channel, to be used when normalizing images.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize=True,
+ size=224,
+ resample=Image.BILINEAR,
+ do_center_crop=True,
+ do_normalize=True,
+ image_mean=None,
+ image_std=None,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+
+ def resize_video(self, video, size, resample="bilinear"):
+ return [self.resize(frame, size, resample, default_to_square=False) for frame in video]
+
+ def crop_video(self, video, size):
+ return [self.center_crop(frame, size) for frame in video]
+
+ def normalize_video(self, video, mean, std):
+ # video can be a list of PIL images, list of NumPy arrays or list of PyTorch tensors
+ # first: convert to list of NumPy arrays
+ video = [self.to_numpy_array(frame) for frame in video]
+
+ # second: stack to get (num_frames, num_channels, height, width)
+ video = np.stack(video, axis=0)
+
+ # third: normalize
+ if not isinstance(mean, np.ndarray):
+ mean = np.array(mean).astype(video.dtype)
+ if not isinstance(std, np.ndarray):
+ std = np.array(std).astype(video.dtype)
+
+ return (video - mean[None, :, None, None]) / std[None, :, None, None]
+
+ def __call__(
+ self, videos: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several video(s).
+
+
+
+ NumPy arrays are converted to PIL images when resizing, so the most efficient is to pass PIL images.
+
+
+
+ Args:
+ videos (`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, `List[List[PIL.Image.Image]]`, `List[List[np.ndarray]]`,:
+ `List[List[torch.Tensor]]`): The video or batch of videos to be prepared. Each video should be a list
+ of frames, which can be either PIL images or NumPy arrays. In case of NumPy arrays/PyTorch tensors,
+ each frame should be of shape (H, W, C), where H and W are frame height and width, and C is a number of
+ channels.
+
+ return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_frames, num_channels,
+ height, width).
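+
+ Examples:
+
+ A minimal usage sketch, assuming `video` is a list of 16 RGB video frames given as PIL images:
+
+ ```python
+ >>> from transformers import VideoMAEFeatureExtractor
+
+ >>> feature_extractor = VideoMAEFeatureExtractor()
+ >>> inputs = feature_extractor(video, return_tensors="pt")
+ >>> list(inputs["pixel_values"].shape)
+ [1, 16, 3, 224, 224]
+ ```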
+ """
+ # Input type checking for clearer error
+ valid_videos = False
+ is_batched = False
+
+ # Check that videos have a valid type
+ if isinstance(videos, (list, tuple)):
+ if isinstance(videos[0], (Image.Image, np.ndarray)) or is_torch_tensor(videos[0]):
+ valid_videos = True
+ elif isinstance(videos[0], (list, tuple)) and (
+ isinstance(videos[0][0], (Image.Image, np.ndarray)) or is_torch_tensor(videos[0][0])
+ ):
+ valid_videos = True
+ is_batched = True
+
+ if not valid_videos:
+ raise ValueError(
+ "Videos must of type `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]` (single"
+ " example), `List[List[PIL.Image.Image]]`, `List[List[np.ndarray]]`, `List[List[torch.Tensor]]` (batch"
+ " of examples)."
+ )
+
+ if not is_batched:
+ videos = [videos]
+
+ # transformations (resizing + center cropping + normalization)
+ if self.do_resize and self.size is not None:
+ videos = [self.resize_video(video, size=self.size, resample=self.resample) for video in videos]
+ if self.do_center_crop and self.size is not None:
+ videos = [self.crop_video(video, size=self.size) for video in videos]
+ if self.do_normalize:
+ videos = [self.normalize_video(video, mean=self.image_mean, std=self.image_std) for video in videos]
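+ # when `do_normalize` is enabled, each video is now a NumPy array of shape (num_frames, num_channels, height, width)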
+
+ # return as BatchFeature
+ data = {"pixel_values": videos}
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ return encoded_inputs
diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py
new file mode 100644
index 000000000000..a807ed7208fc
--- /dev/null
+++ b/src/transformers/models/videomae/modeling_videomae.py
@@ -0,0 +1,1039 @@
+# coding=utf-8
+# Copyright 2022 Multimedia Computing Group, Nanjing University and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch VideoMAE (masked autoencoder) model."""
+
+
+import collections.abc
+import math
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Optional, Set, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from ...utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .configuration_videomae import VideoMAEConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "VideoMAEConfig"
+_CHECKPOINT_FOR_DOC = "MCG-NJU/videomae-base"
+
+VIDEOMAE_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "MCG-NJU/videomae-base",
+ # See all VideoMAE models at https://huggingface.co/models?filter=videomae
+]
+
+
+@dataclass
+class VideoMAEDecoderOutput(ModelOutput):
+ """
+ Class for VideoMAEDecoder's outputs, with potential hidden states and attentions.
+
+ Args:
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, tubelet_size * patch_size ** 2 * num_channels)`):
+ Pixel reconstruction logits.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
+ plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ """
+
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class VideoMAEForPreTrainingOutput(ModelOutput):
+ """
+ Class for VideoMAEForPreTraining's outputs, with potential hidden states and attentions.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`):
+ Pixel reconstruction loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_masked_patches, tubelet_size * patch_size ** 2 * num_channels)`):
+ Pixel reconstruction logits.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
+ plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+# sin-cos position encoding
+# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
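+# Note: the table returned below has shape (1, n_position, d_hid), e.g. (1, 1568, 768) for the base model,
+# so it can be broadcast over the batch dimension when added to the patch embeddings.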
+def get_sinusoid_encoding_table(n_position, d_hid):
+ """Sinusoid position encoding table"""
+ # TODO: make it with torch instead of numpy
+ def get_position_angle_vec(position):
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
+
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
+
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
+
+
+class VideoMAEEmbeddings(nn.Module):
+ """
+ Construct the patch and position embeddings.
+
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.patch_embeddings = VideoMAEPatchEmbeddings(config)
+ self.num_patches = self.patch_embeddings.num_patches
+ # fixed sin-cos embedding
+ self.position_embeddings = get_sinusoid_encoding_table(self.num_patches, config.hidden_size)
+ self.config = config
+
+ def forward(self, pixel_values, bool_masked_pos):
+ # create patch embeddings
+ embeddings = self.patch_embeddings(pixel_values)
+
+ # add position embeddings
+ embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).clone().detach()
+
+ # only keep visible patches
+ # ~bool_masked_pos means visible
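+ # e.g. with the ~90% tube masking ratio used in the paper, only about 10% of the patch embeddings are kept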
+ if bool_masked_pos is not None:
+ batch_size, _, num_channels = embeddings.shape
+ embeddings = embeddings[~bool_masked_pos]
+ embeddings = embeddings.reshape(batch_size, -1, num_channels)
+
+ return embeddings
+
+
+class VideoMAEPatchEmbeddings(nn.Module):
+ """
+ Video to Patch Embedding. This module turns a batch of videos of shape (batch_size, num_frames, num_channels,
+ height, width) into a tensor of shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.
+
+ The seq_len (the number of patches) equals (number of frames // tubelet_size) * (height // patch_size) * (width //
+ patch_size).
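+
+ For instance, assuming 16 frames of 224 x 224 pixels with `tubelet_size=2` and `patch_size=16`, this gives
+ (16 // 2) * (224 // 16) * (224 // 16) = 1568 patches per video.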
+
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ image_size = config.image_size
+ patch_size = config.patch_size
+ num_channels = config.num_channels
+ hidden_size = config.hidden_size
+ num_frames = config.num_frames
+ tubelet_size = config.tubelet_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.tubelet_size = int(tubelet_size)
+ num_patches = (
+ (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
+ )
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+ self.projection = nn.Conv3d(
+ in_channels=num_channels,
+ out_channels=hidden_size,
+ kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
+ stride=(self.tubelet_size, patch_size[0], patch_size[1]),
+ )
+
+ def forward(self, pixel_values):
+ batch_size, num_frames, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+ )
+ # permute to (batch_size, num_channels, num_frames, height, width)
+ pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
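+ # shape (batch_size, num_patches, hidden_size), e.g. (batch_size, 1568, 768) with the default configuration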
+ return embeddings
+
+
+class VideoMAESelfAttention(nn.Module):
+ def __init__(self, config: VideoMAEConfig) -> None:
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+
+ if config.qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(self.all_head_size))
+ self.v_bias = nn.Parameter(torch.zeros(self.all_head_size))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+
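+ # the key projection intentionally has no learnable bias (only query and value do); a constant zero
+ # bias is passed below so that all three projections can use the same functional linear call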
+ k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None
+ keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias)
+ values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias)
+ queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias)
+
+ key_layer = self.transpose_for_scores(keys)
+ value_layer = self.transpose_for_scores(values)
+ query_layer = self.transpose_for_scores(queries)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VideoMAE
+class VideoMAESelfOutput(nn.Module):
+ """
+ The residual connection is defined in VideoMAELayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: VideoMAEConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->VideoMAE
+class VideoMAEAttention(nn.Module):
+ def __init__(self, config: VideoMAEConfig) -> None:
+ super().__init__()
+ self.attention = VideoMAESelfAttention(config)
+ self.output = VideoMAESelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: Set[int]) -> None:
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
+
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->VideoMAE
+class VideoMAEIntermediate(nn.Module):
+ def __init__(self, config: VideoMAEConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->VideoMAE
+class VideoMAEOutput(nn.Module):
+ def __init__(self, config: VideoMAEConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ hidden_states = hidden_states + input_tensor
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE
+class VideoMAELayer(nn.Module):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(self, config: VideoMAEConfig) -> None:
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = VideoMAEAttention(config)
+ self.intermediate = VideoMAEIntermediate(config)
+ self.output = VideoMAEOutput(config)
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_attention_outputs = self.attention(
+ self.layernorm_before(hidden_states), # in VideoMAE, layernorm is applied before self-attention
+ head_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in VideoMAE, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+
+ # second residual connection is done here
+ layer_output = self.output(layer_output, hidden_states)
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->VideoMAE
+class VideoMAEEncoder(nn.Module):
+ def __init__(self, config: VideoMAEConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([VideoMAELayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ layer_head_mask,
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class VideoMAEPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = VideoMAEConfig
+ base_model_prefix = "videomae"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv3d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, VideoMAEEncoder):
+ module.gradient_checkpointing = value
+
+
+VIDEOMAE_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`VideoMAEConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+VIDEOMAE_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`VideoMAEFeatureExtractor`]. See
+ [`VideoMAEFeatureExtractor.__call__`] for details.
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare VideoMAE Model transformer outputting raw hidden-states without any specific head on top.",
+ VIDEOMAE_START_DOCSTRING,
+)
+class VideoMAEModel(VideoMAEPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = VideoMAEEmbeddings(config)
+ self.encoder = VideoMAEEncoder(config)
+
+ if config.use_mean_pooling:
+ self.layernorm = None
+ else:
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(VIDEOMAE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values,
+ bool_masked_pos=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from decord import VideoReader, cpu
+ >>> import numpy as np
+
+ >>> from transformers import VideoMAEFeatureExtractor, VideoMAEModel
+ >>> from huggingface_hub import hf_hub_download
+
+
+ >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
+ ... converted_len = int(clip_len * frame_sample_rate)
+ ... end_idx = np.random.randint(converted_len, seg_len)
+ ... start_idx = end_idx - converted_len
+ ... indices = np.linspace(start_idx, end_idx, num=clip_len)
+ ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
+ ... return indices
+
+
+ >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
+ >>> file_path = hf_hub_download(
+ ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
+ ... )
+ >>> vr = VideoReader(file_path, num_threads=1, ctx=cpu(0))
+
+ >>> # sample 16 frames
+ >>> vr.seek(0)
+ >>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=len(vr))
+ >>> buffer = vr.get_batch(indices).asnumpy()
+
+ >>> # create a list of NumPy arrays
+ >>> video = [buffer[i] for i in range(buffer.shape[0])]
+
+ >>> feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
+ >>> model = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base")
+
+ >>> # prepare video for the model
+ >>> inputs = feature_extractor(video, return_tensors="pt")
+
+ >>> # forward pass
+ >>> outputs = model(**inputs)
+ >>> last_hidden_states = outputs.last_hidden_state
+ >>> list(last_hidden_states.shape)
+ [1, 1568, 768]
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(pixel_values, bool_masked_pos)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ if self.layernorm is not None:
+ sequence_output = self.layernorm(sequence_output)
+
+ if not return_dict:
+ return (sequence_output,) + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class VideoMAEDecoder(nn.Module):
+ def __init__(self, config, num_patches):
+ super().__init__()
+
+ decoder_num_labels = config.num_channels * config.tubelet_size * config.patch_size**2
+
+ decoder_config = deepcopy(config)
+ decoder_config.hidden_size = config.decoder_hidden_size
+ decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
+ decoder_config.num_attention_heads = config.decoder_num_attention_heads
+ decoder_config.intermediate_size = config.decoder_intermediate_size
+ self.decoder_layers = nn.ModuleList(
+ [VideoMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
+ )
+
+ self.norm = nn.LayerNorm(config.decoder_hidden_size)
+ self.head = (
+ nn.Linear(config.decoder_hidden_size, decoder_num_labels) if decoder_num_labels > 0 else nn.Identity()
+ )
+
+ self.gradient_checkpointing = False
+ self.config = config
+
+ def forward(
+ self,
+ hidden_states,
+ return_token_num,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ # apply Transformer layers (blocks)
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ for i, layer_module in enumerate(self.decoder_layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ None,
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
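+ # keep only the last `return_token_num` tokens, i.e. the predictions for the masked patches appended at the end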
+ if return_token_num > 0:
+ hidden_states = hidden_states[:, -return_token_num:]
+
+ # predictor projection
+ hidden_states = self.norm(hidden_states)
+ logits = self.head(hidden_states)
+
+ if not return_dict:
+ return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
+ return VideoMAEDecoderOutput(logits=logits, hidden_states=all_hidden_states, attentions=all_self_attentions)
+
+
+@add_start_docstrings(
+ "The VideoMAE Model transformer with the decoder on top for self-supervised pre-training.",
+ VIDEOMAE_START_DOCSTRING,
+)
+class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.videomae = VideoMAEModel(config)
+
+ self.encoder_to_decoder = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=False)
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
+ self.position_embeddings = get_sinusoid_encoding_table(
+ self.videomae.embeddings.num_patches, config.decoder_hidden_size
+ )
+
+ self.decoder = VideoMAEDecoder(config, num_patches=self.videomae.embeddings.num_patches)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(VIDEOMAE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=VideoMAEForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values,
+ bool_masked_pos,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+ ```python
+ >>> from transformers import VideoMAEFeatureExtractor, VideoMAEForPreTraining
+ >>> import numpy as np
+ >>> import torch
+
+ >>> num_frames = 16
+ >>> video = list(np.random.randn(16, 3, 224, 224))
+
+ >>> feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
+ >>> model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base")
+
+ >>> pixel_values = feature_extractor(video, return_tensors="pt").pixel_values
+
+ >>> num_patches_per_frame = (model.config.image_size // model.config.patch_size) ** 2
+ >>> seq_length = (num_frames // model.config.tubelet_size) * num_patches_per_frame
+ >>> bool_masked_pos = torch.randint(0, 2, (1, seq_length)).bool()
+
+ >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
+ >>> loss = outputs.loss
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.videomae(
+ pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ sequence_output = self.encoder_to_decoder(
+ sequence_output
+ ) # [batch_size, num_visible_patches, decoder_hidden_size]
+ batch_size, seq_len, num_channels = sequence_output.shape
+
+ # we don't unshuffle the correct visible token order, but shuffle the position embeddings accordingly.
+ if bool_masked_pos is None:
+ raise ValueError("One must provided a boolean mask ")
+ expanded_position_embeddings = self.position_embeddings.expand(batch_size, -1, -1).type_as(pixel_values)
+ expanded_position_embeddings = expanded_position_embeddings.to(pixel_values.device).clone().detach()
+ pos_emb_visible = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels)
+ pos_emb_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels)
+
+ # [batch_size, num_patches, decoder_hidden_size]
+ x_full = torch.cat([sequence_output + pos_emb_visible, self.mask_token + pos_emb_mask], dim=1)
+
+ # [batch_size, num_masked_patches, tubelet_size * patch_size * patch_size * num_channels]
+ decoder_outputs = self.decoder(x_full, pos_emb_mask.shape[1])
+ logits = decoder_outputs.logits
+
+ loss = None
+ with torch.no_grad():
+ # calculate the labels to be predicted
+ # first, unnormalize the frames
+ device = pixel_values.device
+ mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, None, :, None, None]
+ std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, None, :, None, None]
+ frames = pixel_values * std + mean # in [0, 1]
+
+ batch_size, time, num_channels, height, width = frames.shape
+ tubelet_size, patch_size = self.config.tubelet_size, self.config.patch_size
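+ # with the default configuration, `frames` has shape (batch_size, 16, 3, 224, 224) and the patchified
+ # targets below have shape (batch_size, 1568, 1536), since 2 * 16 * 16 * 3 = 1536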
+ if self.config.norm_pix_loss:
+ # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
+ frames = frames.view(
+ batch_size,
+ time // tubelet_size,
+ tubelet_size,
+ num_channels,
+ height // patch_size,
+ patch_size,
+ width // patch_size,
+ patch_size,
+ )
+ # step 2: move dimensions to concatenate:
+ frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
+ # step 3: concatenate:
+ frames = frames.view(
+ batch_size,
+ time // tubelet_size * height // patch_size * width // patch_size,
+ tubelet_size * patch_size * patch_size,
+ num_channels,
+ )
+ # step 4: normalize. The authors find that the mean is about 0.48 and standard deviation is about 0.08.
+ frames_norm = (frames - frames.mean(dim=-2, keepdim=True)) / (
+ frames.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
+ )
+ # step 5: reshape to (batch_size, T//ts * H//ps * W//ps, ts * ps * ps * C)
+ videos_patch = frames_norm.view(
+ batch_size,
+ time // tubelet_size * height // patch_size * width // patch_size,
+ tubelet_size * patch_size * patch_size * num_channels,
+ )
+ else:
+ # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
+ frames = frames.view(
+ batch_size,
+ time // tubelet_size,
+ tubelet_size,
+ num_channels,
+ height // patch_size,
+ patch_size,
+ width // patch_size,
+ patch_size,
+ )
+ # step 2: move dimensions to concatenate: (batch_size, T//ts, H//ps, W//ps, ts, ps, ps, C)
+ frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
+ # step 3: concatenate
+ videos_patch = frames.view(
+ batch_size,
+ time // tubelet_size * height // patch_size * width // patch_size,
+ tubelet_size * patch_size * patch_size * num_channels,
+ )
+
+ batch_size, _, num_channels = videos_patch.shape
+ labels = videos_patch[bool_masked_pos].reshape(batch_size, -1, num_channels)
+
+ loss_fct = MSELoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return VideoMAEForPreTrainingOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """VideoMAE Model transformer with a video classification head on top (a linear layer on top of the final hidden state of
+ the [CLS] token) e.g. for ImageNet.""",
+ VIDEOMAE_START_DOCSTRING,
+)
+class VideoMAEForVideoClassification(VideoMAEPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.videomae = VideoMAEModel(config)
+
+ # Classifier head
+ self.fc_norm = nn.LayerNorm(config.hidden_size) if config.use_mean_pooling else None
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(VIDEOMAE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from decord import VideoReader, cpu
+ >>> import numpy as np
+ >>> import torch
+
+ >>> from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification
+ >>> from huggingface_hub import hf_hub_download
+
+
+ >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
+ ... converted_len = int(clip_len * frame_sample_rate)
+ ... end_idx = np.random.randint(converted_len, seg_len)
+ ... start_idx = end_idx - converted_len
+ ... indices = np.linspace(start_idx, end_idx, num=clip_len)
+ ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
+ ... return indices
+
+
+ >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
+ >>> file_path = hf_hub_download(
+ ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
+ ... )
+ >>> vr = VideoReader(file_path, num_threads=1, ctx=cpu(0))
+
+ >>> # sample 16 frames
+ >>> vr.seek(0)
+ >>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=len(vr))
+ >>> buffer = vr.get_batch(indices).asnumpy()
+
+ >>> # create a list of NumPy arrays
+ >>> video = [buffer[i] for i in range(buffer.shape[0])]
+
+ >>> feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
+ >>> model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
+
+ >>> inputs = feature_extractor(video, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+ ... logits = outputs.logits
+
+ >>> # model predicts one of the 400 Kinetics-400 classes
+ >>> predicted_label = logits.argmax(-1).item()
+ >>> print(model.config.id2label[predicted_label])
+ eating spaghetti
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.videomae(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
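+ # mean-pool the final hidden states when `use_mean_pooling` is enabled, otherwise use the final hidden
+ # state of the first token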
+ if self.fc_norm is not None:
+ sequence_output = self.fc_norm(sequence_output.mean(1))
+ else:
+ sequence_output = sequence_output[:, 0]
+
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/src/transformers/models/vilt/__init__.py b/src/transformers/models/vilt/__init__.py
index 7aa27b98deca..d05318202bcd 100644
--- a/src/transformers/models/vilt/__init__.py
+++ b/src/transformers/models/vilt/__init__.py
@@ -18,22 +18,31 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
-_import_structure = {
- "configuration_vilt": ["VILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViltConfig"],
-}
+_import_structure = {"configuration_vilt": ["VILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViltConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_vilt"] = ["ViltFeatureExtractor"]
_import_structure["processing_vilt"] = ["ViltProcessor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_vilt"] = [
"VILT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ViltForImageAndTextRetrieval",
"ViltForImagesAndTextClassification",
+ "ViltForTokenClassification",
"ViltForMaskedLM",
"ViltForQuestionAnswering",
"ViltLayer",
@@ -45,17 +54,28 @@
if TYPE_CHECKING:
from .configuration_vilt import VILT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViltConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_vilt import ViltFeatureExtractor
from .processing_vilt import ViltProcessor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_vilt import (
VILT_PRETRAINED_MODEL_ARCHIVE_LIST,
ViltForImageAndTextRetrieval,
ViltForImagesAndTextClassification,
ViltForMaskedLM,
ViltForQuestionAnswering,
+ ViltForTokenClassification,
ViltLayer,
ViltModel,
ViltPreTrainedModel,
diff --git a/src/transformers/models/vilt/convert_vilt_original_to_pytorch.py b/src/transformers/models/vilt/convert_vilt_original_to_pytorch.py
index 9de026ebec86..3a186e1d2d91 100644
--- a/src/transformers/models/vilt/convert_vilt_original_to_pytorch.py
+++ b/src/transformers/models/vilt/convert_vilt_original_to_pytorch.py
@@ -231,7 +231,10 @@ def convert_vilt_checkpoint(checkpoint_url, pytorch_dump_folder_path):
if nlvr_model:
image1 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_0.jpg", stream=True).raw)
image2 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_0.jpg", stream=True).raw)
- text = "The left image contains twice the number of dogs as the right image, and at least two dogs in total are standing."
+ text = (
+ "The left image contains twice the number of dogs as the right image, and at least two dogs in total are"
+ " standing."
+ )
encoding_1 = processor(image1, text, return_tensors="pt")
encoding_2 = processor(image2, text, return_tensors="pt")
outputs = model(
diff --git a/src/transformers/models/vilt/feature_extraction_vilt.py b/src/transformers/models/vilt/feature_extraction_vilt.py
index 7fdd138750ac..0c64c10959bd 100644
--- a/src/transformers/models/vilt/feature_extraction_vilt.py
+++ b/src/transformers/models/vilt/feature_extraction_vilt.py
@@ -33,6 +33,7 @@
if is_torch_available():
import torch
+
logger = logging.get_logger(__name__)
diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py
index 74759be40477..308358850c98 100755
--- a/src/transformers/models/vilt/modeling_vilt.py
+++ b/src/transformers/models/vilt/modeling_vilt.py
@@ -21,7 +21,6 @@
import torch
import torch.utils.checkpoint
-from packaging import version
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -32,15 +31,27 @@
MaskedLMOutput,
ModelOutput,
SequenceClassifierOutput,
+ TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+ find_pruneable_heads_and_indices,
+ is_torch_greater_or_equal_than_1_10,
+ is_torch_greater_than_1_6,
+ prune_linear_layer,
+)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_vilt import ViltConfig
logger = logging.get_logger(__name__)
+if not is_torch_greater_or_equal_than_1_10:
+ logger.warning(
+ f"You are using torch=={torch.__version__}, but torch>=1.10.0 is required to use "
+ "ViltModel. Please upgrade torch."
+ )
+
_CONFIG_FOR_DOC = "ViltConfig"
_CHECKPOINT_FOR_DOC = "dandelin/vilt-b32-mlm"
@@ -76,13 +87,6 @@ class ViltForImagesAndTextClassificationOutput(ModelOutput):
attentions: Optional[List[Tuple[torch.FloatTensor]]] = None
-# Copied from transformers.models.vit.modeling_vit.to_2tuple
-def to_2tuple(x):
- if isinstance(x, collections.abc.Iterable):
- return x
- return (x, x)
-
-
class ViltEmbeddings(nn.Module):
"""
Construct the text and patch embeddings.
@@ -99,12 +103,7 @@ def __init__(self, config):
self.text_embeddings = TextEmbeddings(config)
# patch embeddings
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
- self.patch_embeddings = PatchEmbeddings(
- image_size=config.image_size,
- patch_size=config.patch_size,
- num_channels=config.num_channels,
- embed_dim=config.hidden_size,
- )
+ self.patch_embeddings = ViltPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
# modality type (text/patch) embeddings
@@ -256,7 +255,7 @@ def __init__(self, config):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
- if version.parse(torch.__version__) > version.parse("1.6.0"):
+ if is_torch_greater_than_1_6:
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long),
@@ -298,26 +297,32 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs
return embeddings
-# Based on timm implementation, which can be found here:
-# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-class PatchEmbeddings(nn.Module):
+class ViltPatchEmbeddings(nn.Module):
"""
Image to Patch Embedding.
"""
- def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
+ def __init__(self, config):
super().__init__()
- image_size = to_2tuple(image_size)
- patch_size = to_2tuple(patch_size)
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
+ self.num_channels = num_channels
self.num_patches = num_patches
- self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values):
batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
x = self.projection(pixel_values)
return x
@@ -1402,3 +1407,90 @@ def forward(
hidden_states=hidden_states,
attentions=attentions,
)
+
+
+@add_start_docstrings(
+ """
+ ViLT Model with a token classification head on top (a linear layer on top of the final hidden-states of the text
+ tokens) e.g. for Named-Entity-Recognition (NER) tasks.
+ """,
+ VILT_START_DOCSTRING,
+)
+class ViltForTokenClassification(ViltPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.vilt = ViltModel(config, add_pooling_layer=False)
+
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ pixel_values=None,
+ pixel_mask=None,
+ head_mask=None,
+ inputs_embeds=None,
+ image_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+
+ Returns:
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.vilt(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ image_embeds=image_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ text_input_size = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ sequence_output = self.dropout(sequence_output)
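+ # only the text part of the sequence (the first `text_input_size` positions) is classified; the image
+ # patch tokens that follow are ignored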
+ logits = self.classifier(sequence_output[:, :text_input_size])
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/src/transformers/models/vision_encoder_decoder/__init__.py b/src/transformers/models/vision_encoder_decoder/__init__.py
index 0757f15ec819..5d501b8feb83 100644
--- a/src/transformers/models/vision_encoder_decoder/__init__.py
+++ b/src/transformers/models/vision_encoder_decoder/__init__.py
@@ -18,32 +18,66 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+)
-_import_structure = {
- "configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig"],
-}
+_import_structure = {"configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_vision_encoder_decoder"] = ["VisionEncoderDecoderModel"]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_vision_encoder_decoder"] = ["TFVisionEncoderDecoderModel"]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]
if TYPE_CHECKING:
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_vision_encoder_decoder import VisionEncoderDecoderModel
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_vision_encoder_decoder import TFVisionEncoderDecoderModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_vision_encoder_decoder import FlaxVisionEncoderDecoderModel
else:
diff --git a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py
index e0478f1e13a5..7042b2548deb 100644
--- a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py
+++ b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py
@@ -301,10 +301,10 @@ def __init__(
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
- "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
- "it has to be equal to the encoder's `hidden_size`. "
- f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
- f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
)
module = self.module_class(config=config, dtype=dtype, **kwargs)
@@ -832,10 +832,9 @@ def from_encoder_decoder_pretrained(
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
- f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
- f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
- "cross attention layers."
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
index 6bbf51409103..682faa3825c5 100644
--- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
+++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
@@ -43,10 +43,10 @@
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
DEPRECATION_WARNING = (
- "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
- "encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
- "a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the labels, no "
- "need to pass them yourself anymore."
+ "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the"
+ " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if"
+ " fine-tuning a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the"
+ " labels, no need to pass them yourself anymore."
)
VISION_ENCODER_DECODER_START_DOCSTRING = r"""
@@ -202,10 +202,10 @@ def __init__(
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
- "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
- "it has to be equal to the encoder's `hidden_size`. "
- f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
- f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
)
# initialize with config
@@ -222,11 +222,13 @@ def __init__(
if self.encoder.config.to_dict() != self.config.encoder.to_dict():
logger.warning(
- f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+ f" {self.config.encoder}"
)
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
logger.warning(
- f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+ f" {self.config.decoder}"
)
# make sure that the individual model's config refers to the shared config
@@ -337,10 +339,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
from_pt = kwargs.pop("from_pt", False)
if from_pt:
raise ValueError(
- "Initializing `TFVisionEncoderDecoderModel` from a pytorch checkpoint is not supported currently. "
- "Use a tensorflow checkpoint instead. If only the pytorch checkpoints are available, "
- "create the encoder and decoder models separately, and use them to initialize `TFVisionEncoderDecoderModel`. "
- "Check `TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained()` for more details."
+ "Initializing `TFVisionEncoderDecoderModel` from a pytorch checkpoint is not supported currently. Use"
+ " a tensorflow checkpoint instead. If only the pytorch checkpoints are available, create the encoder"
+ " and decoder models separately, and use them to initialize `TFVisionEncoderDecoderModel`. Check"
+ " `TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained()` for more details."
)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
@@ -469,10 +471,9 @@ def from_encoder_decoder_pretrained(
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
- f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
- f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
- "cross attention layers."
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
@@ -662,13 +663,13 @@ def call(
warnings.warn(DEPRECATION_WARNING, FutureWarning)
loss = self.hf_compute_loss(labels, logits)
- past_key_values = None
- if decoder_inputs["use_cache"]:
- past_key_values = decoder_outputs[1]
- # The starting index of the remaining elements in `decoder_outputs`
- start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
+ if not return_dict:
+ past_key_values = None
+ if use_cache:
+ past_key_values = decoder_outputs[1]
+ # The starting index of the remaining elements in `decoder_outputs`
+ start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
- if not decoder_inputs["return_dict"]:
if not isinstance(encoder_outputs, tuple):
encoder_outputs = encoder_outputs.to_tuple()
output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
@@ -678,7 +679,7 @@ def call(
return TFSeq2SeqLMOutput(
loss=loss,
logits=decoder_outputs.logits,
- past_key_values=past_key_values,
+ past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
index 37072270a567..d2c4ae6b18cf 100644
--- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
+++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
@@ -173,10 +173,10 @@ def __init__(
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
- "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
- "it has to be equal to the encoder's `hidden_size`. "
- f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
- f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
)
# initialize with config
@@ -195,11 +195,13 @@ def __init__(
if self.encoder.config.to_dict() != self.config.encoder.to_dict():
logger.warning(
- f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+ f" {self.config.encoder}"
)
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
logger.warning(
- f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+ f" {self.config.decoder}"
)
# make sure that the individual model's config refers to the shared config
@@ -369,10 +371,9 @@ def from_encoder_decoder_pretrained(
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
- f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
- f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
- "cross attention layers."
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
@@ -546,8 +547,8 @@ def prepare_inputs_for_generation(
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
- "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported."
- "Please use the respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
+ "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported. Please use the"
+ " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
)
def _reorder_cache(self, past, beam_idx):
diff --git a/src/transformers/models/vision_text_dual_encoder/__init__.py b/src/transformers/models/vision_text_dual_encoder/__init__.py
index 4e705cd03721..89aa78c83112 100644
--- a/src/transformers/models/vision_text_dual_encoder/__init__.py
+++ b/src/transformers/models/vision_text_dual_encoder/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_flax_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available
_import_structure = {
@@ -27,11 +27,21 @@
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_vision_text_dual_encoder"] = ["VisionTextDualEncoderModel"]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_vision_text_dual_encoder"] = ["FlaxVisionTextDualEncoderModel"]
@@ -39,10 +49,20 @@
from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig
from .processing_vision_text_dual_encoder import VisionTextDualEncoderProcessor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_vision_text_dual_encoder import VisionTextDualEncoderModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py
index 4cf6c59882aa..aac1b0e8e93d 100644
--- a/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py
+++ b/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py
@@ -536,9 +536,9 @@ def from_vision_text_pretrained(
# the projection layers are always newly initialized when loading the model
# using pre-trained vision and text model.
logger.warning(
- "The projection layer and logit scale weights `[('visual_projection', 'kernel'), ('text_projection', 'kernel'), ('logit_scale',)]` "
- "are newly initialized. You should probably TRAIN this model on a down-stream task "
- "to be able to use it for predictions and inference."
+ "The projection layer and logit scale weights `[('visual_projection', 'kernel'), ('text_projection',"
+ " 'kernel'), ('logit_scale',)]` are newly initialized. You should probably TRAIN this model on a"
+ " down-stream task to be able to use it for predictions and inference."
)
return model
diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py
index e13c9ca7ef8f..66340deaf492 100755
--- a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py
+++ b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py
@@ -530,9 +530,9 @@ def from_vision_text_pretrained(
# the projection layers are always newly initialized when loading the model
# using pre-trained vision and text model.
logger.warning(
- "The projection layer and logit scale weights `['visual_projection.weight', 'text_projection.weight', 'logit_scale']` "
- "are newly initialized. You should probably TRAIN this model on a down-stream task "
- "to be able to use it for predictions and inference."
+ "The projection layer and logit scale weights `['visual_projection.weight', 'text_projection.weight',"
+ " 'logit_scale']` are newly initialized. You should probably TRAIN this model on a down-stream task to be"
+ " able to use it for predictions and inference."
)
return model
diff --git a/src/transformers/models/visual_bert/__init__.py b/src/transformers/models/visual_bert/__init__.py
index 444929e15179..f7a5390d1348 100644
--- a/src/transformers/models/visual_bert/__init__.py
+++ b/src/transformers/models/visual_bert/__init__.py
@@ -17,14 +17,17 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
-_import_structure = {
- "configuration_visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"],
-}
+_import_structure = {"configuration_visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_visual_bert"] = [
"VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"VisualBertForMultipleChoice",
@@ -41,7 +44,12 @@
if TYPE_CHECKING:
from .configuration_visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_visual_bert import (
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
VisualBertForMultipleChoice,
diff --git a/src/transformers/models/visual_bert/configuration_visual_bert.py b/src/transformers/models/visual_bert/configuration_visual_bert.py
index d4992d5267f8..60a3692644d7 100644
--- a/src/transformers/models/visual_bert/configuration_visual_bert.py
+++ b/src/transformers/models/visual_bert/configuration_visual_bert.py
@@ -23,13 +23,19 @@
VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"uclanlp/visualbert-vqa": "https://huggingface.co/uclanlp/visualbert-vqa/resolve/main/config.json",
"uclanlp/visualbert-vqa-pre": "https://huggingface.co/uclanlp/visualbert-vqa-pre/resolve/main/config.json",
- "uclanlp/visualbert-vqa-coco-pre": "https://huggingface.co/uclanlp/visualbert-vqa-coco-pre/resolve/main/config.json",
+ "uclanlp/visualbert-vqa-coco-pre": (
+ "https://huggingface.co/uclanlp/visualbert-vqa-coco-pre/resolve/main/config.json"
+ ),
"uclanlp/visualbert-vcr": "https://huggingface.co/uclanlp/visualbert-vcr/resolve/main/config.json",
"uclanlp/visualbert-vcr-pre": "https://huggingface.co/uclanlp/visualbert-vcr-pre/resolve/main/config.json",
- "uclanlp/visualbert-vcr-coco-pre": "https://huggingface.co/uclanlp/visualbert-vcr-coco-pre/resolve/main/config.json",
+ "uclanlp/visualbert-vcr-coco-pre": (
+ "https://huggingface.co/uclanlp/visualbert-vcr-coco-pre/resolve/main/config.json"
+ ),
"uclanlp/visualbert-nlvr2": "https://huggingface.co/uclanlp/visualbert-nlvr2/resolve/main/config.json",
"uclanlp/visualbert-nlvr2-pre": "https://huggingface.co/uclanlp/visualbert-nlvr2-pre/resolve/main/config.json",
- "uclanlp/visualbert-nlvr2-coco-pre": "https://huggingface.co/uclanlp/visualbert-nlvr2-coco-pre/resolve/main/config.json"
+ "uclanlp/visualbert-nlvr2-coco-pre": (
+ "https://huggingface.co/uclanlp/visualbert-nlvr2-coco-pre/resolve/main/config.json"
+ )
# See all VisualBERT models at https://huggingface.co/models?filter=visual_bert
}
diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py
index 643411ee7f32..983efada283b 100755
--- a/src/transformers/models/visual_bert/modeling_visual_bert.py
+++ b/src/transformers/models/visual_bert/modeling_visual_bert.py
@@ -158,7 +158,8 @@ def forward(
if (image_text_alignment_mask == 0).sum() != 0:
image_text_alignment_mask[image_text_alignment_mask == 0] = 1 # Avoid divide by zero error
logger.warning(
- "Found 0 values in `image_text_alignment_mask`. Setting them to 1 to avoid divide-by-zero error."
+ "Found 0 values in `image_text_alignment_mask`. Setting them to 1 to avoid divide-by-zero"
+ " error."
)
visual_position_embeddings = visual_position_embeddings / image_text_alignment_mask.unsqueeze(-1)
@@ -928,7 +929,7 @@ def forward(
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = VisualBertForPreTraining.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
- inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")
+ inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
visual_embeds = get_visual_embeddings(image).unsqueeze(0)
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
@@ -978,7 +979,7 @@ def forward(
total_size = attention_mask.size(-1) + visual_attention_mask.size(-1)
if labels.size(-1) != total_size:
raise ValueError(
- f"The labels provided should have same sequence length as total attention mask. "
+ "The labels provided should have the same sequence length as the total attention mask. "
f"Found labels with sequence length {labels.size(-1)}, expected {total_size}."
)
@@ -991,7 +992,7 @@ def forward(
total_size = attention_mask.size(-1) + visual_attention_mask.size(-1)
if labels.size(-1) != total_size:
raise ValueError(
- f"The labels provided should have same sequence length as total attention mask. "
+ "The labels provided should have the same sequence length as the total attention mask. "
f"Found labels with sequence length {labels.size(-1)}, expected {total_size}."
)
@@ -1432,7 +1433,7 @@ def transpose_for_scores(self, x):
def forward(self, query, key, attention_mask):
attention_mask = attention_mask.to(query.dtype)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
- attention_mask = (1.0 - attention_mask) * -10000.0
+ attention_mask = (1.0 - attention_mask) * torch.finfo(query.dtype).min
mixed_query_layer = self.query(query)
mixed_key_layer = self.key(key)
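A minimal sketch (not part of the diff) of why the additive mask switches from the hard-coded -10000.0 to torch.finfo(query.dtype).min: masked positions receive the most negative value representable in the tensor's dtype, so they still vanish after softmax in half precision. The shapes and dtype below are assumptions.

import torch

attention_mask = torch.tensor([[1.0, 1.0, 0.0]], dtype=torch.float16)     # 1 = keep, 0 = mask
extended_mask = (1.0 - attention_mask) * torch.finfo(torch.float16).min   # 0.0 for kept, -65504.0 for masked
scores = torch.zeros(1, 3, dtype=torch.float16) + extended_mask           # dummy attention scores
probs = torch.softmax(scores.float(), dim=-1)                             # masked position gets ~0 probability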
diff --git a/src/transformers/models/vit/__init__.py b/src/transformers/models/vit/__init__.py
index c0331e27d9d5..b30a9ec15d9d 100644
--- a/src/transformers/models/vit/__init__.py
+++ b/src/transformers/models/vit/__init__.py
@@ -17,17 +17,32 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vision_available
-
-
-_import_structure = {
- "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig", "ViTOnnxConfig"],
-}
-
-if is_vision_available():
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+ is_vision_available,
+)
+
+
+_import_structure = {"configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig", "ViTOnnxConfig"]}
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_vit"] = [
"VIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ViTForImageClassification",
@@ -36,14 +51,24 @@
"ViTPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_vit"] = [
"TFViTForImageClassification",
"TFViTModel",
"TFViTPreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_vit"] = [
"FlaxViTForImageClassification",
"FlaxViTModel",
@@ -53,10 +78,20 @@
if TYPE_CHECKING:
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig, ViTOnnxConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_vit import ViTFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_vit import (
VIT_PRETRAINED_MODEL_ARCHIVE_LIST,
ViTForImageClassification,
@@ -65,10 +100,20 @@
ViTPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
diff --git a/src/transformers/models/vit/configuration_vit.py b/src/transformers/models/vit/configuration_vit.py
index e603a6d4f8bc..e84fc6c25f4a 100644
--- a/src/transformers/models/vit/configuration_vit.py
+++ b/src/transformers/models/vit/configuration_vit.py
@@ -56,7 +56,7 @@ class ViTConfig(PretrainedConfig):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
- The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
diff --git a/src/transformers/models/vit/modeling_flax_vit.py b/src/transformers/models/vit/modeling_flax_vit.py
index eaa7c4225e8c..7a438abb0329 100644
--- a/src/transformers/models/vit/modeling_flax_vit.py
+++ b/src/transformers/models/vit/modeling_flax_vit.py
@@ -84,7 +84,7 @@
"""
-class FlaxPatchEmbeddings(nn.Module):
+class FlaxViTPatchEmbeddings(nn.Module):
config: ViTConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@@ -94,6 +94,7 @@ def setup(self):
patch_size = self.config.patch_size
num_patches = (image_size // patch_size) * (image_size // patch_size)
self.num_patches = num_patches
+ self.num_channels = self.config.num_channels
self.projection = nn.Conv(
self.config.hidden_size,
kernel_size=(patch_size, patch_size),
@@ -104,9 +105,14 @@ def setup(self):
)
def __call__(self, pixel_values):
- x = self.projection(pixel_values)
- batch_size, _, _, channels = x.shape
- return jnp.reshape(x, (batch_size, -1, channels))
+ num_channels = pixel_values.shape[-1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
+ )
+ embeddings = self.projection(pixel_values)
+ batch_size, _, _, channels = embeddings.shape
+ return jnp.reshape(embeddings, (batch_size, -1, channels))
class FlaxViTEmbeddings(nn.Module):
@@ -117,7 +123,7 @@ class FlaxViTEmbeddings(nn.Module):
def setup(self):
self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
- self.patch_embeddings = FlaxPatchEmbeddings(self.config, dtype=self.dtype)
+ self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = self.param(
"position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
@@ -143,7 +149,8 @@ class FlaxViTSelfAttention(nn.Module):
def setup(self):
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
- "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
+ f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:"
+ f" {self.config.num_attention_heads}"
)
self.query = nn.Dense(
@@ -419,7 +426,7 @@ def __init__(
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
if input_shape is None:
- input_shape = (1, config.image_size, config.image_size, 3)
+ input_shape = (1, config.image_size, config.image_size, config.num_channels)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
diff --git a/src/transformers/models/vit/modeling_tf_vit.py b/src/transformers/models/vit/modeling_tf_vit.py
index 9d478e968cfc..1db9cf58032d 100644
--- a/src/transformers/models/vit/modeling_tf_vit.py
+++ b/src/transformers/models/vit/modeling_tf_vit.py
@@ -52,19 +52,6 @@
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
-# Inspired by
-# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
-# From PyTorch internals
-def to_2tuple(x):
- if isinstance(x, collections.abc.Iterable):
- return x
- return (x, x)
-
-
-# Based on timm implementation, which can be found here:
-# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-
-
class TFViTEmbeddings(tf.keras.layers.Layer):
"""
Construct the CLS token, position and patch embeddings.
@@ -74,7 +61,7 @@ class TFViTEmbeddings(tf.keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
- self.patch_embeddings = TFPatchEmbeddings(config, name="patch_embeddings")
+ self.patch_embeddings = TFViTPatchEmbeddings(config, name="patch_embeddings")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
@@ -103,19 +90,21 @@ def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
"""
batch_size, seq_len, dim = shape_list(embeddings)
- npatch = seq_len - 1
+ num_patches = seq_len - 1
- _, N, _ = shape_list(self.position_embeddings)
- N -= 1
+ _, num_positions, _ = shape_list(self.position_embeddings)
+ num_positions -= 1
- if npatch == N and height == width:
+ if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
patch_pos_embed = tf.image.resize(
- images=tf.reshape(patch_pos_embed, shape=(1, int(math.sqrt(N)), int(math.sqrt(N)), dim)),
+ images=tf.reshape(
+ patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+ ),
size=(h0, w0),
method="bicubic",
)
@@ -150,27 +139,31 @@ def call(
# Based on timm implementation, which can be found here:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-class TFPatchEmbeddings(tf.keras.layers.Layer):
+class TFViTPatchEmbeddings(tf.keras.layers.Layer):
"""
- Image to Patch Embedding.
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
"""
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
- image_size = to_2tuple(config.image_size)
- patch_size = to_2tuple(config.patch_size)
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = num_patches
- self.num_channels = config.num_channels
- self.embed_dim = config.hidden_size
+ self.num_channels = num_channels
self.config = config
self.projection = tf.keras.layers.Conv2D(
- filters=self.embed_dim,
+ filters=hidden_size,
kernel_size=patch_size,
- strides=self.patch_size,
+ strides=patch_size,
padding="valid",
data_format="channels_last",
use_bias=True,
@@ -183,11 +176,16 @@ def call(
self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values)
+ if tf.executing_eagerly() and num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
+ )
if not interpolate_pos_encoding:
- if getattr(height, "numpy", None) and getattr(width, "numpy", None):
+ if tf.executing_eagerly():
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
- f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
)
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
@@ -200,9 +198,9 @@ def call(
# Change the 2D spatial dimensions to a single temporal dimension.
# shape = (batch_size, num_patches, out_channels=embed_dim)
num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
- x = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
+ embeddings = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
- return x
+ return embeddings
class TFViTSelfAttention(tf.keras.layers.Layer):
diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py
index b2fc044fcb09..7017f232f0e9 100644
--- a/src/transformers/models/vit/modeling_vit.py
+++ b/src/transformers/models/vit/modeling_vit.py
@@ -59,23 +59,9 @@
]
-# Inspired by
-# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
-# From PyTorch internals
-def to_2tuple(x):
- if isinstance(x, collections.abc.Iterable):
- return x
- return (x, x)
-
-
-# Based on timm implementation, which can be found here:
-# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-
-
class ViTEmbeddings(nn.Module):
"""
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
-
"""
def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
@@ -83,12 +69,7 @@ def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
- self.patch_embeddings = PatchEmbeddings(
- image_size=config.image_size,
- patch_size=config.patch_size,
- num_channels=config.num_channels,
- embed_dim=config.hidden_size,
- )
+ self.patch_embeddings = ViTPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -103,9 +84,9 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
- npatch = embeddings.shape[1] - 1
- N = self.position_embeddings.shape[1] - 1
- if npatch == N and height == width:
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+ if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
@@ -115,9 +96,11 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
- patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
- scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
+ patch_pos_embed,
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
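A minimal sketch (not part of the diff) of the position-embedding interpolation performed above, using the ViT-Base grid (196 positions, hidden size 768) and a 20x20 target grid purely as example values.

import math
import torch
import torch.nn as nn

dim, num_positions = 768, 196                                    # 14x14 grid of patch position embeddings
patch_pos_embed = torch.randn(1, num_positions, dim)
side = int(math.sqrt(num_positions))                             # 14
h0, w0 = 20 + 0.1, 20 + 0.1                                      # target grid, nudged as in the code above
grid = patch_pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2)
resized = nn.functional.interpolate(
    grid, scale_factor=(h0 / side, w0 / side), mode="bicubic", align_corners=False
)
# resized.shape == (1, 768, 20, 20), ready to be flattened back to (1, 400, 768)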
@@ -134,9 +117,9 @@ def forward(
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
- batch_size, seq_len, _ = embeddings.size()
if bool_masked_pos is not None:
- mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+ seq_length = embeddings.shape[1]
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
# replace the masked visual tokens by mask_tokens
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
@@ -156,40 +139,42 @@ def forward(
return embeddings
-# Based on timm implementation, which can be found here:
-# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-class PatchEmbeddings(nn.Module):
+class ViTPatchEmbeddings(nn.Module):
"""
- Image to Patch Embedding.
-
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
"""
- def __init__(
- self,
- image_size: int = 224,
- patch_size: Union[int, Tuple[int, int]] = 16,
- num_channels: int = 3,
- embed_dim: int = 768,
- ):
+ def __init__(self, config):
super().__init__()
- image_size = to_2tuple(image_size)
- patch_size = to_2tuple(patch_size)
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
+ self.num_channels = num_channels
self.num_patches = num_patches
- self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
+ )
if not interpolate_pos_encoding:
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
- f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
)
- x = self.projection(pixel_values).flatten(2).transpose(1, 2)
- return x
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return embeddings
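A minimal sketch (not part of the diff) of the shape arithmetic done by the new ViTPatchEmbeddings; the ViT-Base defaults (224x224 images, 16x16 patches, 3 channels, hidden size 768) are assumed here.

import torch
import torch.nn as nn

pixel_values = torch.randn(2, 3, 224, 224)                 # (batch_size, num_channels, height, width)
projection = nn.Conv2d(3, 768, kernel_size=16, stride=16)  # one output column per 16x16 patch
features = projection(pixel_values)                        # (2, 768, 14, 14)
embeddings = features.flatten(2).transpose(1, 2)           # (2, 196, 768) = (batch_size, num_patches, hidden_size)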
class ViTSelfAttention(nn.Module):
@@ -213,7 +198,7 @@ def __init__(self, config: ViTConfig) -> None:
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
@@ -245,7 +230,7 @@ def forward(
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -523,7 +508,7 @@ def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_t
# Initialize weights and apply final processing
self.post_init()
- def get_input_embeddings(self) -> PatchEmbeddings:
+ def get_input_embeddings(self) -> ViTPatchEmbeddings:
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
@@ -612,7 +597,8 @@ def forward(self, hidden_states):
@add_start_docstrings(
- "ViT Model with a decoder on top for masked image modeling, as proposed in `SimMIM `__.",
+ "ViT Model with a decoder on top for masked image modeling, as proposed in"
+ " [SimMIM](https://arxiv.org/abs/2111.09886).",
VIT_START_DOCSTRING,
)
class ViTForMaskedImageModeling(ViTPreTrainedModel):
@@ -622,7 +608,11 @@ def __init__(self, config: ViTConfig) -> None:
self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
self.decoder = nn.Sequential(
- nn.Conv2d(in_channels=config.hidden_size, out_channels=config.encoder_stride**2 * 3, kernel_size=1),
+ nn.Conv2d(
+ in_channels=config.hidden_size,
+ out_channels=config.encoder_stride**2 * config.num_channels,
+ kernel_size=1,
+ ),
nn.PixelShuffle(config.encoder_stride),
)
@@ -687,7 +677,7 @@ def forward(
# Reshape to (batch_size, num_channels, height, width)
sequence_output = sequence_output[:, 1:]
batch_size, sequence_length, num_channels = sequence_output.shape
- height = width = int(sequence_length**0.5)
+ height = width = math.floor(sequence_length**0.5)
sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
# Reconstruct pixel values
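A minimal sketch (not part of the diff) showing why the decoder's 1x1 convolution now emits encoder_stride**2 * num_channels channels: nn.PixelShuffle(r) rearranges (B, C*r*r, H, W) into (B, C, H*r, W*r), so the channel count must scale with num_channels for non-RGB inputs. The sizes below are assumptions.

import torch
import torch.nn as nn

hidden_size, encoder_stride, num_channels = 768, 16, 3
decoder = nn.Sequential(
    nn.Conv2d(hidden_size, encoder_stride**2 * num_channels, kernel_size=1),
    nn.PixelShuffle(encoder_stride),
)
feature_map = torch.randn(2, hidden_size, 14, 14)    # (batch_size, hidden_size, height/16, width/16)
reconstruction = decoder(feature_map)                # (2, 3, 224, 224): full-resolution pixel logits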
diff --git a/src/transformers/models/vit_mae/__init__.py b/src/transformers/models/vit_mae/__init__.py
index cc3569b8b7f6..b785f7f6ee39 100644
--- a/src/transformers/models/vit_mae/__init__.py
+++ b/src/transformers/models/vit_mae/__init__.py
@@ -17,14 +17,23 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+)
-_import_structure = {
- "configuration_vit_mae": ["VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMAEConfig"],
-}
+_import_structure = {"configuration_vit_mae": ["VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMAEConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_vit_mae"] = [
"VIT_MAE_PRETRAINED_MODEL_ARCHIVE_LIST",
"ViTMAEForPreTraining",
@@ -33,7 +42,12 @@
"ViTMAEPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_vit_mae"] = [
"TFViTMAEForPreTraining",
"TFViTMAEModel",
@@ -43,7 +57,12 @@
if TYPE_CHECKING:
from .configuration_vit_mae import VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMAEConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_vit_mae import (
VIT_MAE_PRETRAINED_MODEL_ARCHIVE_LIST,
ViTMAEForPreTraining,
@@ -52,7 +71,12 @@
ViTMAEPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_vit_mae import TFViTMAEForPreTraining, TFViTMAEModel, TFViTMAEPreTrainedModel
diff --git a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
index f464b6665aff..d43bfa45b1fb 100644
--- a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
+++ b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
@@ -84,7 +84,7 @@ class TFViTMAEDecoderOutput(ModelOutput):
Class for TFViTMAEDecoder's outputs, with potential hidden states and attentions.
Args:
- logits (`tf.Tensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
@@ -109,7 +109,7 @@ class TFViTMAEForPreTrainingOutput(ModelOutput):
Args:
loss (`tf.Tensor` of shape `(1,)`):
Pixel reconstruction loss.
- logits (`tf.Tensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
@@ -133,13 +133,6 @@ class TFViTMAEForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None
-# copied from transformers.models.vit.modeling_tf_vit.to_2tuple
-def to_2tuple(x):
- if isinstance(x, collections.abc.Iterable):
- return x
- return (x, x)
-
-
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
"""
Create 2D sin/cos positional embeddings.
@@ -212,7 +205,7 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer):
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
- self.patch_embeddings = TFPatchEmbeddings(config, name="patch_embeddings")
+ self.patch_embeddings = TFViTMAEPatchEmbeddings(config, name="patch_embeddings")
self.num_patches = self.patch_embeddings.num_patches
self.config = config
@@ -297,30 +290,30 @@ def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor:
return embeddings, mask, ids_restore
-# Based on timm implementation, which can be found here:
-# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-class TFPatchEmbeddings(tf.keras.layers.Layer):
+class TFViTMAEPatchEmbeddings(tf.keras.layers.Layer):
"""
- Image to Patch Embedding.
-
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
"""
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
- image_size = to_2tuple(config.image_size)
- patch_size = to_2tuple(config.patch_size)
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = num_patches
- self.num_channels = config.num_channels
- self.embed_dim = config.hidden_size
+ self.num_channels = num_channels
self.config = config
self.projection = tf.keras.layers.Conv2D(
- filters=self.embed_dim,
- kernel_size=self.patch_size,
- strides=self.patch_size,
+ filters=hidden_size,
+ kernel_size=patch_size,
+ strides=patch_size,
padding="valid",
data_format="channels_last",
kernel_initializer="glorot_uniform", # following torch.nn.Linear
@@ -330,10 +323,16 @@ def __init__(self, config: ViTMAEConfig, **kwargs):
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values)
- if getattr(height, "numpy", None) and getattr(width, "numpy", None):
+ if tf.executing_eagerly():
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values matches the one set in the"
+ " configuration."
+ )
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
- f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
)
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
@@ -723,8 +722,8 @@ def serving(self, inputs):
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
-
- return self.call(inputs)
+ output = self.call(inputs)
+ return self.serving_output(output)
VIT_MAE_START_DOCSTRING = r"""
@@ -843,6 +842,18 @@ def call(
return outputs
+ def serving_output(self, output: TFViTMAEModelOutput) -> TFViTMAEModelOutput:
+ hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+
+ return TFViTMAEModelOutput(
+ last_hidden_state=output.last_hidden_state,
+ mask=output.mask,
+ ids_restore=output.ids_restore,
+ hidden_states=hidden_states,
+ attentions=attentions,
+ )
+
class TFViTMAEDecoder(tf.keras.layers.Layer):
def __init__(self, config, num_patches, **kwargs):
@@ -968,50 +979,110 @@ def get_input_embeddings(self):
def _prune_heads(self, heads_to_prune):
raise NotImplementedError
- def patchify(self, imgs):
+ def patchify(self, pixel_values):
"""
- imgs: (batch_size, height, width, 3) x: (batch_size, num_patches, patch_size**2 *3)
+ Args:
+ pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`):
+ Pixel values.
+
+ Returns:
+ `tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
+ Patchified pixel values.
"""
- imgs = tf.cond(
- tf.math.equal(shape_list(imgs)[1], 3), lambda: tf.transpose(imgs, perm=(0, 2, 3, 1)), lambda: imgs
+ patch_size, num_channels = self.config.patch_size, self.config.num_channels
+ # make sure channels are last
+ pixel_values = tf.cond(
+ tf.math.equal(shape_list(pixel_values)[1], num_channels),
+ lambda: tf.transpose(pixel_values, perm=(0, 2, 3, 1)),
+ lambda: pixel_values,
)
- p = self.vit.embeddings.patch_embeddings.patch_size[0]
- tf.debugging.assert_equal(shape_list(imgs)[1], shape_list(imgs)[2])
- tf.debugging.assert_equal(shape_list(imgs)[1] % p, 0)
+ # sanity checks
+ tf.debugging.assert_equal(
+ shape_list(pixel_values)[1],
+ shape_list(pixel_values)[2],
+ message="Make sure the pixel values have a squared size",
+ )
+ tf.debugging.assert_equal(
+ shape_list(pixel_values)[1] % patch_size,
+ 0,
+ message="Make sure the pixel values have a size that is divisible by the patch size",
+ )
+ tf.debugging.assert_equal(
+ shape_list(pixel_values)[3],
+ num_channels,
+ message=(
+ "Make sure the number of channels of the pixel values is equal to the one set in the configuration"
+ ),
+ )
- h = w = shape_list(imgs)[2] // p
- x = tf.reshape(imgs, (shape_list(imgs)[0], h, p, w, p, 3))
- x = tf.einsum("nhpwqc->nhwpqc", x)
- x = tf.reshape(x, (shape_list(imgs)[0], h * w, p**2 * 3))
- return x
+ # patchify
+ batch_size = shape_list(pixel_values)[0]
+ num_patches_one_direction = shape_list(pixel_values)[2] // patch_size
+ patchified_pixel_values = tf.reshape(
+ pixel_values,
+ (batch_size, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size, num_channels),
+ )
+ patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values)
+ patchified_pixel_values = tf.reshape(
+ patchified_pixel_values,
+ (batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels),
+ )
+ return patchified_pixel_values
- def unpatchify(self, x):
+ def unpatchify(self, patchified_pixel_values):
"""
- x: (batch_size, num_patches, patch_size**2 *3) imgs: (batch_size, height, width, 3)
+ Args:
+ patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
+ Patchified pixel values.
+
+ Returns:
+ `tf.Tensor` of shape `(batch_size, height, width, num_channels)`:
+ Pixel values.
"""
- p = self.vit.embeddings.patch_embeddings.patch_size[0]
- h = w = int(shape_list(x)[1] ** 0.5)
- tf.debugging.assert_equal(h * w, shape_list(x)[1])
+ patch_size, num_channels = self.config.patch_size, self.config.num_channels
+ num_patches_one_direction = int(shape_list(patchified_pixel_values)[1] ** 0.5)
+ # sanity check
+ tf.debugging.assert_equal(
+ num_patches_one_direction * num_patches_one_direction,
+ shape_list(patchified_pixel_values)[1],
+ message="Make sure that the number of patches can be squared",
+ )
- x = tf.reshape(x, (shape_list(x)[0], h, w, p, p, 3))
- x = tf.einsum("nhwpqc->nhpwqc", x)
- imgs = tf.reshape(x, (shape_list(x)[0], h * p, h * p, 3))
- return imgs
+ # unpatchify
+ batch_size = shape_list(patchified_pixel_values)[0]
+ patchified_pixel_values = tf.reshape(
+ patchified_pixel_values,
+ (batch_size, num_patches_one_direction, num_patches_one_direction, patch_size, patch_size, num_channels),
+ )
+ patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values)
+ pixel_values = tf.reshape(
+ patchified_pixel_values,
+ (batch_size, num_patches_one_direction * patch_size, num_patches_one_direction * patch_size, num_channels),
+ )
+ return pixel_values
- def forward_loss(self, imgs, pred, mask):
+ def forward_loss(self, pixel_values, pred, mask):
"""
- imgs: [batch_size, height, width, 3] pred: [batch_size, num_patches, patch_size**2*3] mask: [N, L], 0 is keep,
- 1 is remove,
+ Args:
+ pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`):
+ Pixel values.
+ pred (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
+ Predicted pixel values.
+ mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Tensor indicating which patches are masked (1) and which are not (0).
+
+ Returns:
+ `tf.Tensor`: Pixel reconstruction loss.
"""
- target = self.patchify(imgs)
+ target = self.patchify(pixel_values)
if self.config.norm_pix_loss:
mean = tf.reduce_mean(target, axis=-1, keepdims=True)
var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
target = (target - mean) / (var + 1.0e-6) ** 0.5
loss = (pred - target) ** 2
- loss = tf.reduce_mean(loss, axis=-1) # [N, L], mean loss per patch
+ loss = tf.reduce_mean(loss, axis=-1) # [batch_size, num_patches], mean loss per patch
loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask) # mean loss on removed patches
return loss
@@ -1084,3 +1155,15 @@ def call(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
+
+ def serving_output(self, output: TFViTMAEForPreTrainingOutput) -> TFViTMAEForPreTrainingOutput:
+ hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+
+ return TFViTMAEForPreTrainingOutput(
+ logits=output.logits,
+ mask=output.mask,
+ ids_restore=output.ids_restore,
+ hidden_states=hidden_states,
+ attentions=attentions,
+ )
diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py
index 473ccd14feb0..0667bdd73c55 100755
--- a/src/transformers/models/vit_mae/modeling_vit_mae.py
+++ b/src/transformers/models/vit_mae/modeling_vit_mae.py
@@ -86,7 +86,7 @@ class ViTMAEDecoderOutput(ModelOutput):
Class for ViTMAEDecoder's outputs, with potential hidden states and attentions.
Args:
- logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
@@ -111,7 +111,7 @@ class ViTMAEForPreTrainingOutput(ModelOutput):
Args:
loss (`torch.FloatTensor` of shape `(1,)`):
Pixel reconstruction loss.
- logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
@@ -135,13 +135,6 @@ class ViTMAEForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
-# copied from transformers.models.vit.modeling_vit.to_2tuple ViT->ViTMAE
-def to_2tuple(x):
- if isinstance(x, collections.abc.Iterable):
- return x
- return (x, x)
-
-
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
"""
Create 2D sin/cos positional embeddings.
@@ -213,12 +206,7 @@ def __init__(self, config):
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
- self.patch_embeddings = PatchEmbeddings(
- image_size=config.image_size,
- patch_size=config.patch_size,
- num_channels=config.num_channels,
- embed_dim=config.hidden_size,
- )
+ self.patch_embeddings = ViTMAEPatchEmbeddings(config)
self.num_patches = self.patch_embeddings.num_patches
# fixed sin-cos embedding
self.position_embeddings = nn.Parameter(
@@ -291,27 +279,33 @@ def forward(self, pixel_values, noise=None):
return embeddings, mask, ids_restore
-# Based on timm implementation, which can be found here:
-# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-class PatchEmbeddings(nn.Module):
+class ViTMAEPatchEmbeddings(nn.Module):
"""
- Image to Patch Embedding.
-
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
"""
- def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
+ def __init__(self, config):
super().__init__()
- image_size = to_2tuple(image_size)
- patch_size = to_2tuple(patch_size)
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
+ self.num_channels = num_channels
self.num_patches = num_patches
- self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values):
batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
+ )
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
@@ -342,7 +336,7 @@ def __init__(self, config: ViTMAEConfig) -> None:
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
@@ -374,7 +368,7 @@ def forward(
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -868,37 +862,86 @@ class PreTrainedModel
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
- def patchify(self, imgs):
+ def patchify(self, pixel_values):
"""
- imgs: (N, 3, H, W) x: (N, L, patch_size**2 *3)
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values.
+
+ Returns:
+ `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
+ Patchified pixel values.
"""
- p = self.vit.embeddings.patch_embeddings.patch_size[0]
- assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+ patch_size, num_channels = self.config.patch_size, self.config.num_channels
+ # sanity checks
+ if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0):
+ raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
+ if pixel_values.shape[1] != num_channels:
+ raise ValueError(
+ "Make sure the number of channels of the pixel values is equal to the one set in the configuration"
+ )
- h = w = imgs.shape[2] // p
- x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
- x = torch.einsum("nchpwq->nhwpqc", x)
- x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
- return x
+ # patchify
+ batch_size = pixel_values.shape[0]
+ num_patches_one_direction = pixel_values.shape[2] // patch_size
+ patchified_pixel_values = pixel_values.reshape(
+ batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size
+ )
+ patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
+ patchified_pixel_values = patchified_pixel_values.reshape(
+ batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels
+ )
+ return patchified_pixel_values
- def unpatchify(self, x):
+ def unpatchify(self, patchified_pixel_values):
"""
- x: (N, L, patch_size**2 *3) imgs: (N, 3, H, W)
- """
- p = self.vit.embeddings.patch_embeddings.patch_size[0]
- h = w = int(x.shape[1] ** 0.5)
- assert h * w == x.shape[1]
+ Args:
+ patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
+ Patchified pixel values.
- x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
- x = torch.einsum("nhwpqc->nchpwq", x)
- imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
- return imgs
+ Returns:
+ `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
+ Pixel values.
+ """
+ patch_size, num_channels = self.config.patch_size, self.config.num_channels
+ num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5)
+ # sanity check
+ if num_patches_one_direction**2 != patchified_pixel_values.shape[1]:
+ raise ValueError("Make sure that the number of patches can be squared")
+
+ # unpatchify
+ batch_size = patchified_pixel_values.shape[0]
+ patchified_pixel_values = patchified_pixel_values.reshape(
+ batch_size,
+ num_patches_one_direction,
+ num_patches_one_direction,
+ patch_size,
+ patch_size,
+ num_channels,
+ )
+ patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
+ pixel_values = patchified_pixel_values.reshape(
+ batch_size,
+ num_channels,
+ num_patches_one_direction * patch_size,
+ num_patches_one_direction * patch_size,
+ )
+ return pixel_values
- def forward_loss(self, imgs, pred, mask):
+ def forward_loss(self, pixel_values, pred, mask):
"""
- imgs: [N, 3, H, W] pred: [N, L, p*p*3] mask: [N, L], 0 is keep, 1 is remove,
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values.
+ pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
+ Predicted pixel values.
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Tensor indicating which patches are masked (1) and which are not (0).
+
+ Returns:
+ `torch.FloatTensor`: Pixel reconstruction loss.
"""
- target = self.patchify(imgs)
+ target = self.patchify(pixel_values)
if self.config.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
@@ -958,8 +1001,8 @@ def forward(
ids_restore = outputs.ids_restore
mask = outputs.mask
- decoder_outputs = self.decoder(latent, ids_restore) # [N, L, p*p*3]
- logits = decoder_outputs.logits
+ decoder_outputs = self.decoder(latent, ids_restore)
+ logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels)
loss = self.forward_loss(pixel_values, logits, mask)
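
For reference, here is a minimal standalone sketch of the reshaping described by the new `patchify`/`unpatchify` docstrings above. The helper names and the concrete shapes are assumptions for illustration (a square input whose side is divisible by the patch size); the einsum patterns mirror the ones in the diff.

```python
import torch

# hypothetical standalone helpers mirroring the reshaping used in patchify/unpatchify above
def patchify(pixel_values: torch.Tensor, patch_size: int) -> torch.Tensor:
    # (batch_size, num_channels, height, width) -> (batch_size, num_patches, patch_size**2 * num_channels)
    batch_size, num_channels, height, _ = pixel_values.shape
    n = height // patch_size
    x = pixel_values.reshape(batch_size, num_channels, n, patch_size, n, patch_size)
    x = torch.einsum("nchpwq->nhwpqc", x)
    return x.reshape(batch_size, n * n, patch_size**2 * num_channels)


def unpatchify(patches: torch.Tensor, patch_size: int, num_channels: int) -> torch.Tensor:
    # exact inverse of patchify: no arithmetic, only reshapes and a permutation
    batch_size, num_patches, _ = patches.shape
    n = int(num_patches**0.5)
    x = patches.reshape(batch_size, n, n, patch_size, patch_size, num_channels)
    x = torch.einsum("nhwpqc->nchpwq", x)
    return x.reshape(batch_size, num_channels, n * patch_size, n * patch_size)


pixel_values = torch.randn(2, 3, 224, 224)
assert torch.equal(unpatchify(patchify(pixel_values, 16), 16, 3), pixel_values)
```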
diff --git a/src/transformers/models/wav2vec2/__init__.py b/src/transformers/models/wav2vec2/__init__.py
index 93783b668283..306c2197f4c3 100644
--- a/src/transformers/models/wav2vec2/__init__.py
+++ b/src/transformers/models/wav2vec2/__init__.py
@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+)
_import_structure = {
@@ -28,7 +34,12 @@
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_wav2vec2"] = [
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForAudioFrameClassification",
@@ -41,7 +52,12 @@
"Wav2Vec2PreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_wav2vec2"] = [
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFWav2Vec2ForCTC",
@@ -49,7 +65,12 @@
"TFWav2Vec2PreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_wav2vec2"] = [
"FlaxWav2Vec2ForCTC",
"FlaxWav2Vec2ForPreTraining",
@@ -64,7 +85,12 @@
from .processing_wav2vec2 import Wav2Vec2Processor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForAudioFrameClassification,
@@ -77,7 +103,12 @@
Wav2Vec2PreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
@@ -85,7 +116,12 @@
TFWav2Vec2PreTrainedModel,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_wav2vec2 import (
FlaxWav2Vec2ForCTC,
FlaxWav2Vec2ForPreTraining,
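
The import changes in this file adopt the repository's newer optional-dependency pattern: the availability check is wrapped in `try`/`except OptionalDependencyNotAvailable` so that a missing backend simply skips registering the corresponding submodules instead of failing at import time. A condensed sketch of the pattern, assuming a `transformers` version that ships `OptionalDependencyNotAvailable` and with the module contents shortened for illustration:

```python
from transformers.utils import OptionalDependencyNotAvailable, is_torch_available

# always-available, framework-agnostic pieces
_import_structure = {"configuration_wav2vec2": ["Wav2Vec2Config"]}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass  # torch not installed: don't register the PyTorch modeling module
else:
    _import_structure["modeling_wav2vec2"] = ["Wav2Vec2Model"]
```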
diff --git a/src/transformers/models/wav2vec2/configuration_wav2vec2.py b/src/transformers/models/wav2vec2/configuration_wav2vec2.py
index f675f6799f66..6b96d9fc3f67 100644
--- a/src/transformers/models/wav2vec2/configuration_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/configuration_wav2vec2.py
@@ -78,13 +78,13 @@ class Wav2Vec2Config(PretrainedConfig):
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
The dropout probabilitiy for quantized feature encoder states.
- conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
- conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
- conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
length of *conv_kernel* defines the number of convolutional layers and has to match the length of
*conv_dim*.
@@ -156,13 +156,13 @@ class Wav2Vec2Config(PretrainedConfig):
instance of [`Wav2Vec2ForSequenceClassification`].
classifier_proj_size (`int`, *optional*, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
- tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+ tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
- tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+ tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
*XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
- tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+ tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
*XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
xvector_output_dim (`int`, *optional*, defaults to 512):
@@ -288,10 +288,10 @@ def __init__(
or (len(self.conv_dim) != self.num_feat_extract_layers)
):
raise ValueError(
- "Configuration for convolutional layers is incorrect. "
- "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
- f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
- f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
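
The error-message change above is a pure line-length reflow; for readers unfamiliar with the check itself, a small sketch (assuming only that `transformers` is installed) that trips the length validation by passing convolution settings of mismatched lengths:

```python
from transformers import Wav2Vec2Config

# conv_dim, conv_stride and conv_kernel must all describe the same number of layers;
# here they have 2, 3 and 2 entries respectively, so the constructor raises.
try:
    Wav2Vec2Config(conv_dim=(512, 512), conv_stride=(5, 2, 2), conv_kernel=(10, 3))
except ValueError as err:
    print(err)
```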
diff --git a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
index db77a9ea1603..89ae3ad21c2e 100644
--- a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
@@ -77,7 +77,8 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
if hf_shape != value.shape:
raise ValueError(
- f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
)
if weight_type == "weight":
@@ -148,14 +149,16 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
@@ -163,14 +166,16 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
diff --git a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
index 595fb11192ad..14b1d688c9d7 100644
--- a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
@@ -171,8 +171,9 @@ def __call__(
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of {self.sampling_rate}. "
- f"Please make sure that the provided `raw_speech` input was sampled with {self.sampling_rate} and not {sampling_rate}."
+ f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+ f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
+ f" {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
index 7709e43ab955..68cce7d7d405 100644
--- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
@@ -120,7 +120,7 @@ def _compute_mask_indices(
CPU as part of the preprocessing during training.
Args:
- shape: the the shape for which to compute masks.
+ shape: the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
mask_prob:
probability for each token to be chosen as start of the span to be masked. this will be multiplied by
@@ -137,7 +137,8 @@ def _compute_mask_indices(
if mask_length > sequence_length:
raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
+ f" `sequence_length`: {sequence_length}`"
)
# compute number of masked spans in batch
@@ -186,7 +187,7 @@ def _sample_negative_indices(features_shape: Tuple, num_negatives: int, attentio
batch_size, sequence_length, hidden_size = features_shape
if sequence_length <= 1:
raise ValueError(
- f"`features should have `sequence_length` > 1, but are of shape "
+ "`features should have `sequence_length` > 1, but are of shape "
f"(batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})."
)
@@ -386,7 +387,8 @@ def setup(self):
raise NotImplementedError("At the moment only ``config.feat_extact_norm == 'layer'`` is supported")
else:
raise ValueError(
- f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group', 'layer']"
+ f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group',"
+ " 'layer']"
)
def __call__(self, hidden_states):
@@ -444,7 +446,8 @@ def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
)
dense = partial(
diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
index bac62f148ccb..854831e45a09 100644
--- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
@@ -25,7 +25,13 @@
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
-from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
+from ...modeling_tf_utils import (
+ TFPreTrainedModel,
+ booleans_processing,
+ get_initializer,
+ keras_serializable,
+ unpack_inputs,
+)
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ModelOutput,
@@ -133,12 +139,14 @@ def input_values_processing(func, config, input_values, **kwargs):
output[parameter_names[i]] = input
else:
raise ValueError(
- f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
+ f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
+ f" {parameter_names[i]}."
)
elif isinstance(input_values, Mapping):
if "inputs" in input_values:
warnings.warn(
- "The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
+ "The `inputs` argument is deprecated and will be removed in a future version, use `input_values`"
+ " instead.",
FutureWarning,
)
@@ -146,7 +154,8 @@ def input_values_processing(func, config, input_values, **kwargs):
if "decoder_cached_states" in input_values:
warnings.warn(
- "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
+ "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
+ " `past_key_values` instead.",
FutureWarning,
)
output["past_key_values"] = input_values.pop("decoder_cached_states")
@@ -166,7 +175,8 @@ def input_values_processing(func, config, input_values, **kwargs):
output[parameter_names[0]] = input_values
else:
raise ValueError(
- f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
+ f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for"
+ f" {parameter_names[0]}."
)
for name in parameter_names:
@@ -234,7 +244,7 @@ def _compute_mask_indices(
Computes random mask spans for a given shape
Args:
- shape: the the shape for which to compute masks.
+ shape: the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob:
@@ -254,15 +264,17 @@ def _compute_mask_indices(
if mask_length > sequence_length:
raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
+ f" `sequence_length`: {sequence_length}`"
)
# compute number of masked spans in batch
- num_masked_spans = int(mask_prob * sequence_length / mask_length + tf.random.uniform((1,)))
- num_masked_spans = max(num_masked_spans, min_masks)
+ num_masked_spans = mask_prob * sequence_length / mask_length + tf.random.uniform((1,))
+ num_masked_spans = tf.maximum(num_masked_spans, min_masks)
+ num_masked_spans = tf.cast(num_masked_spans, tf.int32)
# make sure num masked indices <= sequence_length
- if num_masked_spans * mask_length > sequence_length:
- num_masked_spans = sequence_length // mask_length
+ num_masked_spans = tf.math.minimum(sequence_length // mask_length, num_masked_spans)
+ num_masked_spans = tf.squeeze(num_masked_spans)
# SpecAugment mask to fill
spec_aug_mask = tf.zeros((batch_size, sequence_length), dtype=tf.int32)
@@ -286,13 +298,14 @@ def _compute_mask_indices(
# scatter indices to mask
spec_aug_mask = _scatter_values_on_batch_indices(
- tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, spec_aug_mask.shape
+ tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, tf.shape(spec_aug_mask)
)
return spec_aug_mask
-def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
+# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -441,9 +454,11 @@ def _check_if_input_shape_is_none(self, input_shape):
dim = input_shape[self.axis]
if dim is None:
raise ValueError(
- "Axis " + str(self.axis) + " of "
- "input tensor should have a defined dimension "
- "but the layer received an input with shape " + str(input_shape) + "."
+ "Axis "
+ + str(self.axis)
+ + " of input tensor should have a defined dimension but the layer received an input with shape "
+ + str(input_shape)
+ + "."
)
def _set_number_of_groups_for_instance_norm(self, input_shape):
@@ -457,22 +472,27 @@ def _check_size_of_dimensions(self, input_shape):
dim = input_shape[self.axis]
if dim < self.groups:
raise ValueError(
- "Number of groups (" + str(self.groups) + ") cannot be "
- "more than the number of channels (" + str(dim) + ")."
+ "Number of groups ("
+ + str(self.groups)
+ + ") cannot be more than the number of channels ("
+ + str(dim)
+ + ")."
)
if dim % self.groups != 0:
raise ValueError(
- "Number of groups (" + str(self.groups) + ") must be a "
- "multiple of the number of channels (" + str(dim) + ")."
+ "Number of groups ("
+ + str(self.groups)
+ + ") must be a multiple of the number of channels ("
+ + str(dim)
+ + ")."
)
def _check_axis(self):
if self.axis == 0:
raise ValueError(
- "You are trying to normalize your batch axis. Do you want to "
- "use tf.layer.batch_normalization instead"
+ "You are trying to normalize your batch axis. Do you want to use tf.layer.batch_normalization instead"
)
def _create_input_spec(self, input_shape):
@@ -838,7 +858,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
- message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
)
if attention_mask is not None:
@@ -848,7 +871,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
- message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
@@ -864,7 +890,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
@@ -881,7 +910,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
- message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
)
attn_output = tf.transpose(
@@ -1321,7 +1353,15 @@ def __init__(self, config, *inputs, **kwargs):
"to train/fine-tine this model, you need a GPU or a TPU"
)
- @tf.function
+ @tf.function(
+ input_signature=[
+ {
+ "input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"),
+ "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
+ "token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
+ }
+ ]
+ )
def serving(self, inputs):
output = self.call(input_values=inputs, training=False)
@@ -1513,14 +1553,14 @@ def call(
return outputs
def serving_output(self, output):
- hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
- attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+ hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFWav2Vec2BaseModelOutput(
last_hidden_state=output.last_hidden_state,
extract_features=output.extract_features,
- hidden_states=hs,
- attentions=attns,
+ hidden_states=hidden_states,
+ attentions=attentions,
)
@@ -1555,6 +1595,7 @@ def freeze_feature_encoder(self):
"""
self.wav2vec2.feature_extractor.trainable = False
+ @unpack_inputs
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
@@ -1609,9 +1650,8 @@ def call(
>>> # compute loss
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
- >>> # wrap processor as target processor to encode labels
- >>> with processor.as_target_processor():
- ... labels = processor(transcription, return_tensors="tf").input_ids
+ >>> # Pass transcription as `text` to encode labels
+ >>> labels = processor(text=transcription, return_tensors="tf").input_ids
>>> loss = model(input_values, labels=labels).loss
```"""
@@ -1677,6 +1717,8 @@ def call(
loss = tf.reduce_sum(loss)
if self.config.ctc_loss_reduction == "mean":
loss = tf.reduce_mean(loss)
+
+ loss = tf.reshape(loss, (1,))
else:
loss = None
@@ -1692,6 +1734,6 @@ def call(
)
def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
- hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
- attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
- return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
+ hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+ return TFCausalLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
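
A note on the `_compute_mask_indices` change in this file: the Python-level `int()`, `max()` and the shape-dependent `if` are replaced with TensorFlow ops, presumably so the span-count computation stays traceable when the model runs under `tf.function` (as with the new `serving` input signature). A minimal sketch of the same arithmetic, with assumed example values:

```python
import tensorflow as tf

# assumed example values, not taken from the PR
mask_prob, mask_length, min_masks = 0.05, 10, 2
sequence_length = 292

num_masked_spans = mask_prob * sequence_length / mask_length + tf.random.uniform((1,))
num_masked_spans = tf.maximum(num_masked_spans, min_masks)
num_masked_spans = tf.cast(num_masked_spans, tf.int32)
# cap the number of spans without a Python-level `if`, so it works on symbolic shapes too
num_masked_spans = tf.math.minimum(sequence_length // mask_length, num_masked_spans)
num_masked_spans = tf.squeeze(num_masked_spans)
```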
diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index f58ec9a3363e..9f6780800396 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -33,6 +33,8 @@
MaskedLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
@@ -88,35 +90,6 @@
]
-@dataclass
-class Wav2Vec2BaseModelOutput(ModelOutput):
- """
- Output type of [`Wav2Vec2BaseModelOutput`], with potential hidden states and attentions.
-
- Args:
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Sequence of hidden-states at the output of the last layer of the model.
- extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
- Sequence of extracted feature vectors of the last convolutional layer of the model.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- last_hidden_state: torch.FloatTensor = None
- extract_features: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
@dataclass
class Wav2Vec2ForPreTrainingOutput(ModelOutput):
"""
@@ -159,38 +132,6 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
diversity_loss: Optional[torch.FloatTensor] = None
-@dataclass
-class XVectorOutput(ModelOutput):
- """
- Output type of [`Wav2Vec2ForXVector`].
-
- Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Classification loss.
- logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
- Classification hidden states before AMSoftmax.
- embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
- Utterance embeddings used for vector similarity-based retrieval.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- loss: Optional[torch.FloatTensor] = None
- logits: torch.FloatTensor = None
- embeddings: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
@@ -293,7 +234,7 @@ def compute_num_masked_span(input_length):
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
- # add offset to the starting indexes so that that indexes now create a span
+ # add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
@@ -636,7 +577,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -652,7 +594,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -673,7 +616,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -801,10 +745,12 @@ def forward(
if attention_mask is not None:
# make sure padded tokens output 0
- hidden_states[~attention_mask] = 0.0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
- attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
@@ -888,10 +834,12 @@ def forward(
if attention_mask is not None:
# make sure padded tokens are not attended to
- hidden_states[~attention_mask] = 0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
- attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
@@ -1022,11 +970,8 @@ def forward(self, hidden_states, mask_time_indices=None):
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
- codevectors = (
- codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
- .sum(-2)
- .view(batch_size, sequence_length, -1)
- )
+ codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
+ codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
return codevectors, perplexity
@@ -1470,13 +1415,12 @@ def forward(
```python
>>> import torch
- >>> from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining
+ >>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
>>> from datasets import load_dataset
- >>> import soundfile as sf
- >>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
- >>> model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
+ >>> model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
@@ -1910,6 +1854,7 @@ def __init__(self, config):
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ self.num_labels = config.num_labels
self.init_weights()
@@ -1953,6 +1898,7 @@ def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -1985,12 +1931,17 @@ def forward(
logits = self.classifier(hidden_states)
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return output
return TokenClassifierOutput(
- loss=None,
+ loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
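
One detail worth highlighting in the PyTorch changes above: the additive attention mask now uses `torch.finfo(dtype).min` instead of the hard-coded `-10000.0`, so the masking constant always matches the precision the model runs in. A small illustration with assumed values:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
hidden_dtype = torch.float16  # assumed dtype of the hidden states

extended_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_dtype)
# use the dtype's most negative finite value rather than a fixed -10000.0
extended_mask = extended_mask * torch.finfo(hidden_dtype).min

print(extended_mask)  # 0 for real tokens, finfo(float16).min for the padded position
```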
diff --git a/src/transformers/models/wav2vec2/processing_wav2vec2.py b/src/transformers/models/wav2vec2/processing_wav2vec2.py
index 1470c254dc63..5763d4d59eea 100644
--- a/src/transformers/models/wav2vec2/processing_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/processing_wav2vec2.py
@@ -43,6 +43,7 @@ class Wav2Vec2Processor(ProcessorMixin):
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)
self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
@@ -70,7 +71,35 @@ def __call__(self, *args, **kwargs):
[`~Wav2Vec2Processor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
[`~PreTrainedTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
"""
- return self.current_processor(*args, **kwargs)
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor(*args, **kwargs)
+
+ if "raw_speech" in kwargs:
+ warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
+ audio = kwargs.pop("raw_speech")
+ else:
+ audio = kwargs.pop("audio", None)
+ text = kwargs.pop("text", None)
+ if len(args) > 0:
+ audio = args[0]
+ args = args[1:]
+
+ if audio is None and text is None:
+ raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+ if audio is not None:
+ inputs = self.feature_extractor(audio, *args, **kwargs)
+ if text is not None:
+ encodings = self.tokenizer(text, **kwargs)
+
+ if text is None:
+ return inputs
+ elif audio is None:
+ return encodings
+ else:
+ inputs["labels"] = encodings["input_ids"]
+ return inputs
def pad(self, *args, **kwargs):
"""
@@ -79,7 +108,28 @@ def pad(self, *args, **kwargs):
[`~Wav2Vec2Processor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
[`~PreTrainedTokenizer.pad`]. Please refer to the docstring of the above two methods for more information.
"""
- return self.current_processor.pad(*args, **kwargs)
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor.pad(*args, **kwargs)
+
+ input_features = kwargs.pop("input_features", None)
+ labels = kwargs.pop("labels", None)
+ if len(args) > 0:
+ input_features = args[0]
+ args = args[1:]
+
+ if input_features is not None:
+ input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
+ if labels is not None:
+ labels = self.tokenizer.pad(labels, **kwargs)
+
+ if labels is None:
+ return input_features
+ elif input_features is None:
+ return labels
+ else:
+ input_features["labels"] = labels["input_ids"]
+ return input_features
def batch_decode(self, *args, **kwargs):
"""
@@ -101,6 +151,13 @@ def as_target_processor(self):
Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
Wav2Vec2.
"""
+ warnings.warn(
+ "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+ "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+ "your audio inputs, or in a separate call."
+ )
+ self._in_target_context_manager = True
self.current_processor = self.tokenizer
yield
self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
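
With `as_target_processor` deprecated, labels are now encoded through the `text` argument of the regular call. A short usage sketch; the checkpoint name and the dummy audio array are assumptions for illustration:

```python
import numpy as np
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
raw_speech = np.zeros(16_000, dtype=np.float32)  # one second of silence as dummy audio

# audio goes through the feature extractor
inputs = processor(audio=raw_speech, sampling_rate=16_000, return_tensors="pt")

# labels are encoded by passing `text=`, instead of wrapping the call in `as_target_processor()`
labels = processor(text="A MAN SAID TO THE UNIVERSE SIR I EXIST", return_tensors="pt").input_ids
```

Passing `audio` and `text` together in a single call stores the encoded transcription under the `labels` key of the returned batch, as implemented in the new `__call__` above.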
diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py
index 53a6cfe1c07a..1e77959400e4 100644
--- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py
@@ -61,7 +61,9 @@
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json",
},
"tokenizer_config_file": {
- "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer_config.json",
+ "facebook/wav2vec2-base-960h": (
+ "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer_config.json"
+ ),
},
}
@@ -601,7 +603,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return (vocab_file,)
@@ -717,7 +719,9 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json"
},
"tokenizer_config_file": {
- "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer.json",
+ "facebook/wav2vec2-base-960h": (
+ "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer.json"
+ ),
},
}
model_input_names = ["input_values", "attention_mask"]
@@ -748,7 +752,8 @@ def __init__(
)
warnings.warn(
- "The class `Wav2Vec2Tokenizer` is deprecated and will be removed in version 5 of Transformers. Please use `Wav2Vec2Processor` or `Wav2Vec2CTCTokenizer` instead.",
+ "The class `Wav2Vec2Tokenizer` is deprecated and will be removed in version 5 of Transformers. Please use"
+ " `Wav2Vec2Processor` or `Wav2Vec2CTCTokenizer` instead.",
FutureWarning,
)
@@ -917,6 +922,6 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return (vocab_file,)
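
The `save_vocabulary` change only affects serialization: the vocab file is now pretty-printed, key-sorted and newline-terminated. A tiny sketch with a toy vocabulary:

```python
import json

encoder = {"<pad>": 0, "|": 4, "E": 5, "T": 6}  # toy vocabulary, not the real one
with open("vocab.json", "w", encoding="utf-8") as f:
    # previously a single compact line; now stable, diff-friendly output
    f.write(json.dumps(encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
```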
diff --git a/src/transformers/models/wav2vec2_conformer/__init__.py b/src/transformers/models/wav2vec2_conformer/__init__.py
new file mode 100644
index 000000000000..df9fe20e2571
--- /dev/null
+++ b/src/transformers/models/wav2vec2_conformer/__init__.py
@@ -0,0 +1,74 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {
+ "configuration_wav2vec2_conformer": [
+ "WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "Wav2Vec2ConformerConfig",
+ ],
+}
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_wav2vec2_conformer"] = [
+ "WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "Wav2Vec2ConformerForAudioFrameClassification",
+ "Wav2Vec2ConformerForCTC",
+ "Wav2Vec2ConformerForPreTraining",
+ "Wav2Vec2ConformerForSequenceClassification",
+ "Wav2Vec2ConformerForXVector",
+ "Wav2Vec2ConformerModel",
+ "Wav2Vec2ConformerPreTrainedModel",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_wav2vec2_conformer import (
+ WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ Wav2Vec2ConformerConfig,
+ )
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_wav2vec2_conformer import (
+ WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ Wav2Vec2ConformerForAudioFrameClassification,
+ Wav2Vec2ConformerForCTC,
+ Wav2Vec2ConformerForPreTraining,
+ Wav2Vec2ConformerForSequenceClassification,
+ Wav2Vec2ConformerForXVector,
+ Wav2Vec2ConformerModel,
+ Wav2Vec2ConformerPreTrainedModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py
new file mode 100644
index 000000000000..9c5e4d205b9a
--- /dev/null
+++ b/src/transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py
@@ -0,0 +1,357 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Wav2Vec2Conformer model configuration"""
+
+import functools
+import operator
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "facebook/wav2vec2-conformer-large-rel-pos": (
+ "https://huggingface.co/facebook/wav2vec2-conformer-large-rel-pos/resolve/main/config.json"
+ ),
+}
+
+
+class Wav2Vec2ConformerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Wav2Vec2ConformerModel`]. It is used to
+ instantiate a Wav2Vec2Conformer model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Wav2Vec2Conformer
+ [facebook/wav2vec2-conformer-large-rel-pos](https://huggingface.co/facebook/wav2vec2-conformer-large-rel-pos)
+ architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*):
+ Vocabulary size of the Wav2Vec2Conformer model. Defines the number of different tokens that can be
+ represented by the `inputs_ids` passed when calling [`Wav2Vec2ConformerModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ final_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for the final projection layer of [`Wav2Vec2ConformerForCTC`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ feat_extract_norm (`str`, *optional*, defaults to `"group"`):
+ The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
+ normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
+ convolutional layers.
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for output of the feature encoder.
+ feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the 1D convolutional layers of the feature
+ extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for quantized feature encoder states.
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
+ feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
+ *conv_dim*.
+ conv_bias (`bool`, *optional*, defaults to `False`):
+ Whether the 1D convolutional layers have a bias.
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
+ Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
+ embeddings layer.
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+ Number of groups of 1D convolutional positional embeddings layer.
+ apply_spec_augment (`bool`, *optional*, defaults to `True`):
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
+ Recognition](https://arxiv.org/abs/1904.08779).
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
+ procedure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
+ reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
+ mask_time_length (`int`, *optional*, defaults to 10):
+ Length of vector span along the time axis.
+ mask_time_min_masks (`int`, *optional*, defaults to 2):
+ The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
+ irrespective of `mask_time_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
+ mask_time_min_masks''
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+ masking procedure generates ''mask_feature_prob*len(feature_axis)/mask_feature_length'' independent masks over
+ the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
+ True`.
+ mask_feature_length (`int`, *optional*, defaults to 10):
+ Length of vector span along the feature axis.
+ mask_feature_min_masks (`int`, *optional*, defaults to 0):
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+ step, irrespective of `mask_feature_prob`. Only relevant if
+ ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
+ num_codevectors_per_group (`int`, *optional*, defaults to 320):
+ Number of entries in each quantization codebook (group).
+ num_codevector_groups (`int`, *optional*, defaults to 2):
+ Number of codevector groups for product codevector quantization.
+ contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
+ The temperature *kappa* in the contrastive loss.
+ feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for the output of the feature encoder that's used by the quantizer.
+ num_negatives (`int`, *optional*, defaults to 100):
+ Number of negative samples for the contrastive loss.
+ codevector_dim (`int`, *optional*, defaults to 256):
+ Dimensionality of the quantized feature vectors.
+ proj_codevector_dim (`int`, *optional*, defaults to 256):
+ Dimensionality of the final projection of both the quantized and the transformer features.
+ diversity_loss_weight (`float`, *optional*, defaults to 0.1):
+ The weight of the codebook diversity loss component.
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+ instance of [`Wav2Vec2ConformerForCTC`].
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+ of [`Wav2Vec2ConformerForCTC`].
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
+ instance of [`Wav2Vec2ConformerForSequenceClassification`].
+ classifier_proj_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the projection before token mean-pooling for classification.
+ tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+ A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
+ module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
+ tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
+ *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
+ tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+ A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
+ *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
+ xvector_output_dim (`int`, *optional*, defaults to 512):
+ Dimensionality of the *XVector* embedding vectors.
+ add_adapter (`bool`, *optional*, defaults to `False`):
+ Whether a convolutional network should be stacked on top of the Wav2Vec2Conformer Encoder. Can be very
+ useful for warm-starting Wav2Vec2Conformer for SpeechEncoderDecoder models.
+ adapter_kernel_size (`int`, *optional*, defaults to 3):
+ Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+ adapter_stride (`int`, *optional*, defaults to 2):
+ Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+ num_adapter_layers (`int`, *optional*, defaults to 3):
+ Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
+ True`.
+ output_hidden_size (`int`, *optional*):
+ Dimensionality of the encoder output layer. If not defined, this defaults to *hidden_size*. Only relevant
+ if `add_adapter is True`.
+ position_embeddings_type (`str`, *optional*, defaults to `"relative"`):
+ Can be set to `"relative"` or `"rotary"` for relative or rotary position embeddings respectively. If left to
+ `None`, no position embeddings are applied.
+ rotary_embedding_base (`int`, *optional*, defaults to 10000):
+ If `"rotary"` position embeddings are used, defines the size of the embedding base.
+ max_source_positions (`int`, *optional*, defaults to 5000):
+ if `"relative"` position embeddings are used, defines the maximum source input positions.
+ conv_depthwise_kernel_size (`int`, defaults to 31):
+ Kernel size of convolutional depthwise 1D layer in Conformer blocks.
+ conformer_conv_dropout (`float`, defaults to 0.1):
+ The dropout probability for all convolutional layers in Conformer blocks.
+
+ Example:
+
+ ```python
+ >>> from transformers import Wav2Vec2ConformerModel, Wav2Vec2ConformerConfig
+
+ >>> # Initializing a Wav2Vec2Conformer facebook/wav2vec2-conformer-large-rel-pos style configuration
+ >>> configuration = Wav2Vec2ConformerConfig()
+
+ >>> # Initializing a model from the facebook/wav2vec2-conformer-large-rel-pos style configuration
+ >>> model = Wav2Vec2ConformerModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "wav2vec2-conformer"
+
+ def __init__(
+ self,
+ vocab_size=None,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout=0.1,
+ activation_dropout=0.1,
+ attention_dropout=0.1,
+ feat_proj_dropout=0.0,
+ feat_quantizer_dropout=0.0,
+ final_dropout=0.1,
+ layerdrop=0.1,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ feat_extract_norm="group",
+ feat_extract_activation="gelu",
+ conv_dim=(512, 512, 512, 512, 512, 512, 512),
+ conv_stride=(5, 2, 2, 2, 2, 2, 2),
+ conv_kernel=(10, 3, 3, 3, 3, 2, 2),
+ conv_bias=False,
+ num_conv_pos_embeddings=128,
+ num_conv_pos_embedding_groups=16,
+ apply_spec_augment=True,
+ mask_time_prob=0.05,
+ mask_time_length=10,
+ mask_time_min_masks=2,
+ mask_feature_prob=0.0,
+ mask_feature_length=10,
+ mask_feature_min_masks=0,
+ num_codevectors_per_group=320,
+ num_codevector_groups=2,
+ contrastive_logits_temperature=0.1,
+ num_negatives=100,
+ codevector_dim=256,
+ proj_codevector_dim=256,
+ diversity_loss_weight=0.1,
+ ctc_loss_reduction="sum",
+ ctc_zero_infinity=False,
+ use_weighted_layer_sum=False,
+ classifier_proj_size=256,
+ tdnn_dim=(512, 512, 512, 512, 1500),
+ tdnn_kernel=(5, 3, 3, 1, 1),
+ tdnn_dilation=(1, 2, 3, 1, 1),
+ xvector_output_dim=512,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ add_adapter=False,
+ adapter_kernel_size=3,
+ adapter_stride=2,
+ num_adapter_layers=3,
+ output_hidden_size=None,
+ position_embeddings_type="relative",
+ rotary_embedding_base=10000,
+ max_source_positions=5000,
+ conv_depthwise_kernel_size=31,
+ conformer_conv_dropout=0.1,
+ **kwargs
+ ):
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+ self.hidden_size = hidden_size
+ self.feat_extract_norm = feat_extract_norm
+ self.feat_extract_activation = feat_extract_activation
+ self.conv_dim = list(conv_dim)
+ self.conv_stride = list(conv_stride)
+ self.conv_kernel = list(conv_kernel)
+ self.conv_bias = conv_bias
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+ self.num_feat_extract_layers = len(self.conv_dim)
+ self.num_hidden_layers = num_hidden_layers
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.num_attention_heads = num_attention_heads
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.feat_proj_dropout = feat_proj_dropout
+ self.final_dropout = final_dropout
+ self.layerdrop = layerdrop
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ self.vocab_size = vocab_size
+ self.use_weighted_layer_sum = use_weighted_layer_sum
+ self.max_source_positions = max_source_positions
+ self.position_embeddings_type = position_embeddings_type
+ self.rotary_embedding_base = rotary_embedding_base
+
+ if (
+ (len(self.conv_stride) != self.num_feat_extract_layers)
+ or (len(self.conv_kernel) != self.num_feat_extract_layers)
+ or (len(self.conv_dim) != self.num_feat_extract_layers)
+ ):
+ raise ValueError(
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ )
+
+ # Conformer-block related
+ self.conv_depthwise_kernel_size = conv_depthwise_kernel_size
+ self.conformer_conv_dropout = conformer_conv_dropout
+
+ # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
+ self.apply_spec_augment = apply_spec_augment
+ self.mask_time_prob = mask_time_prob
+ self.mask_time_length = mask_time_length
+ self.mask_time_min_masks = mask_time_min_masks
+ self.mask_feature_prob = mask_feature_prob
+ self.mask_feature_length = mask_feature_length
+ self.mask_feature_min_masks = mask_feature_min_masks
+
+ # parameters for pretraining with codevector quantized representations
+ self.num_codevectors_per_group = num_codevectors_per_group
+ self.num_codevector_groups = num_codevector_groups
+ self.contrastive_logits_temperature = contrastive_logits_temperature
+ self.feat_quantizer_dropout = feat_quantizer_dropout
+ self.num_negatives = num_negatives
+ self.codevector_dim = codevector_dim
+ self.proj_codevector_dim = proj_codevector_dim
+ self.diversity_loss_weight = diversity_loss_weight
+
+ # ctc loss
+ self.ctc_loss_reduction = ctc_loss_reduction
+ self.ctc_zero_infinity = ctc_zero_infinity
+
+ # adapter
+ self.add_adapter = add_adapter
+ self.adapter_kernel_size = adapter_kernel_size
+ self.adapter_stride = adapter_stride
+ self.num_adapter_layers = num_adapter_layers
+ self.output_hidden_size = output_hidden_size or hidden_size
+
+ # SequenceClassification-specific parameter. Feel free to ignore for other classes.
+ self.classifier_proj_size = classifier_proj_size
+
+ # XVector-specific parameters. Feel free to ignore for other classes.
+ self.tdnn_dim = list(tdnn_dim)
+ self.tdnn_kernel = list(tdnn_kernel)
+ self.tdnn_dilation = list(tdnn_dilation)
+ self.xvector_output_dim = xvector_output_dim
+
+ @property
+ def inputs_to_logits_ratio(self):
+ return functools.reduce(operator.mul, self.conv_stride, 1)
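For orientation, a small usage sketch of the property defined above (not part of the patch): with the default `conv_stride=(5, 2, 2, 2, 2, 2, 2)` the feature extractor downsamples the raw waveform by a factor of 320, i.e. one encoder frame per 20 ms of 16 kHz audio.

```python
# Sketch assuming a transformers version that ships this model.
import functools
import operator

from transformers import Wav2Vec2ConformerConfig

config = Wav2Vec2ConformerConfig()
print(functools.reduce(operator.mul, config.conv_stride, 1))  # 5 * 2 * 2 * 2 * 2 * 2 * 2 = 320
print(config.inputs_to_logits_ratio)  # 320: one encoder frame per 320 waveform samples (20 ms at 16 kHz)
```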
diff --git a/src/transformers/models/wav2vec2_conformer/convert_wav2vec2_conformer_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/wav2vec2_conformer/convert_wav2vec2_conformer_original_pytorch_checkpoint_to_pytorch.py
new file mode 100644
index 000000000000..26ccf9239b61
--- /dev/null
+++ b/src/transformers/models/wav2vec2_conformer/convert_wav2vec2_conformer_original_pytorch_checkpoint_to_pytorch.py
@@ -0,0 +1,307 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert Wav2Vec2Conformer checkpoint."""
+
+
+import argparse
+import json
+import os
+
+import fairseq
+import torch
+from fairseq.data import Dictionary
+
+from transformers import (
+ Wav2Vec2ConformerConfig,
+ Wav2Vec2ConformerForCTC,
+ Wav2Vec2ConformerForPreTraining,
+ Wav2Vec2CTCTokenizer,
+ Wav2Vec2FeatureExtractor,
+ Wav2Vec2Processor,
+ logging,
+)
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+MAPPING = {
+ "post_extract_proj": "feature_projection.projection",
+ "encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
+ "self_attn.linear_k": "encoder.layers.*.self_attn.linear_k",
+ "self_attn.linear_v": "encoder.layers.*.self_attn.linear_v",
+ "self_attn.linear_q": "encoder.layers.*.self_attn.linear_q",
+ "self_attn.pos_bias_u": "encoder.layers.*.self_attn.pos_bias_u",
+ "self_attn.pos_bias_v": "encoder.layers.*.self_attn.pos_bias_v",
+ "self_attn.linear_out": "encoder.layers.*.self_attn.linear_out",
+ "self_attn.linear_pos": "encoder.layers.*.self_attn.linear_pos",
+ "self_attn.rotary_emb": "encoder.embed_positions",
+ "self_attn_layer_norm": "encoder.layers.*.self_attn_layer_norm",
+ "conv_module.pointwise_conv1": "encoder.layers.*.conv_module.pointwise_conv1",
+ "conv_module.pointwise_conv2": "encoder.layers.*.conv_module.pointwise_conv2",
+ "conv_module.depthwise_conv": "encoder.layers.*.conv_module.depthwise_conv",
+ "conv_module.batch_norm": "encoder.layers.*.conv_module.batch_norm",
+ "conv_module.layer_norm": "encoder.layers.*.conv_module.layer_norm",
+ "ffn1.w_1": "encoder.layers.*.ffn1.intermediate_dense",
+ "ffn1.w_2": "encoder.layers.*.ffn1.output_dense",
+ "ffn1.layer_norm": "encoder.layers.*.ffn1_layer_norm",
+ "ffn2.w_1": "encoder.layers.*.ffn2.intermediate_dense",
+ "ffn2.w_2": "encoder.layers.*.ffn2.output_dense",
+ "ffn2.layer_norm": "encoder.layers.*.ffn2_layer_norm",
+ "final_layer_norm": "encoder.layers.*.final_layer_norm",
+ "encoder.layer_norm": "encoder.layer_norm",
+ "w2v_model.layer_norm": "feature_projection.layer_norm",
+ "quantizer.weight_proj": "quantizer.weight_proj",
+ "quantizer.vars": "quantizer.codevectors",
+ "project_q": "project_q",
+ "final_proj": "project_hid",
+ "w2v_encoder.proj": "lm_head",
+ "mask_emb": "masked_spec_embed",
+}
+TOP_LEVEL_KEYS = [
+ "lm_head",
+ "quantizer.weight_proj",
+ "quantizer.codevectors",
+ "project_q",
+ "project_hid",
+]
+
+
+def set_recursively(hf_pointer, key, value, full_name, weight_type):
+ for attribute in key.split("."):
+ hf_pointer = getattr(hf_pointer, attribute)
+
+ if weight_type is not None:
+ hf_shape = getattr(hf_pointer, weight_type).shape
+ else:
+ hf_shape = hf_pointer.shape
+
+ if hf_shape != value.shape:
+ raise ValueError(
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
+
+ if weight_type == "weight":
+ hf_pointer.weight.data = value
+ elif weight_type == "weight_g":
+ hf_pointer.weight_g.data = value
+ elif weight_type == "weight_v":
+ hf_pointer.weight_v.data = value
+ elif weight_type == "bias":
+ hf_pointer.bias.data = value
+ elif weight_type == "running_mean":
+ hf_pointer.running_mean.data = value
+ elif weight_type == "running_var":
+ hf_pointer.running_var.data = value
+ elif weight_type == "num_batches_tracked":
+ hf_pointer.num_batches_tracked.data = value
+ elif weight_type == "inv_freq":
+ hf_pointer.inv_freq.data = value
+ else:
+ hf_pointer.data = value
+
+ logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
+
+
+def recursively_load_weights(fairseq_model, hf_model, is_headless):
+ unused_weights = []
+ fairseq_dict = fairseq_model.state_dict()
+
+ feature_extractor = hf_model.wav2vec2_conformer.feature_extractor
+
+ for name, value in fairseq_dict.items():
+ is_used = False
+ if "conv_layers" in name:
+ load_conv_layer(
+ name,
+ value,
+ feature_extractor,
+ unused_weights,
+ hf_model.config.feat_extract_norm == "group",
+ )
+ is_used = True
+ else:
+ for key, mapped_key in MAPPING.items():
+ mapped_key = "wav2vec2_conformer." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
+ if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
+ is_used = True
+ if "*" in mapped_key:
+ layer_index = name.split(key)[0].split(".")[-2]
+ mapped_key = mapped_key.replace("*", layer_index)
+ if "pos_bias_u" in name:
+ weight_type = None
+ elif "pos_bias_v" in name:
+ weight_type = None
+ elif "weight_g" in name:
+ weight_type = "weight_g"
+ elif "weight_v" in name:
+ weight_type = "weight_v"
+ elif "bias" in name:
+ weight_type = "bias"
+ elif "weight" in name:
+ # TODO: don't match quantizer.weight_proj
+ weight_type = "weight"
+ elif "running_mean" in name:
+ weight_type = "running_mean"
+ elif "inv_freq" in name:
+ weight_type = "inv_freq"
+ elif "running_var" in name:
+ weight_type = "running_var"
+ elif "num_batches_tracked" in name:
+ weight_type = "num_batches_tracked"
+ else:
+ weight_type = None
+ set_recursively(hf_model, mapped_key, value, name, weight_type)
+ continue
+ if not is_used:
+ unused_weights.append(name)
+
+ logger.warning(f"Unused weights: {unused_weights}")
+
+
+# Copied from transformers.models.wav2vec2.convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.load_conv_layer
+def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
+ name = full_name.split("conv_layers.")[-1]
+ items = name.split(".")
+ layer_id = int(items[0])
+ type_id = int(items[1])
+
+ if type_id == 0:
+ if "bias" in name:
+ if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
+ raise ValueError(
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
+ feature_extractor.conv_layers[layer_id].conv.bias.data = value
+ logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
+ elif "weight" in name:
+ if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
+ raise ValueError(
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
+ feature_extractor.conv_layers[layer_id].conv.weight.data = value
+ logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
+ elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
+ if "bias" in name:
+ if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
+ raise ValueError(
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
+ )
+ feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
+ logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
+ elif "weight" in name:
+ if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
+ raise ValueError(
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
+ )
+ feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
+ logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
+ else:
+ unused_weights.append(full_name)
+
+
+@torch.no_grad()
+def convert_wav2vec2_conformer_checkpoint(
+ checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
+):
+ """
+ Copy/paste/tweak model's weights to transformers design.
+ """
+ if config_path is not None:
+ config = Wav2Vec2ConformerConfig.from_pretrained(config_path, hidden_act="swish")
+ else:
+ config = Wav2Vec2ConformerConfig()
+
+ if "rope" in checkpoint_path:
+ config.position_embeddings_type = "rotary"
+
+ if is_finetuned:
+ if dict_path:
+ target_dict = Dictionary.load(dict_path)
+
+ # important change bos & pad token id since CTC symbol is <pad> and
+ # not <s> as in fairseq
+ config.bos_token_id = target_dict.pad_index
+ config.pad_token_id = target_dict.bos_index
+ config.eos_token_id = target_dict.eos_index
+ config.vocab_size = len(target_dict.symbols)
+ vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
+ if not os.path.isdir(pytorch_dump_folder_path):
+ logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
+ return
+ os.makedirs(pytorch_dump_folder_path, exist_ok=True)
+ vocab_dict = target_dict.indices
+
+ # fairseq has the <pad> and <s> switched
+ vocab_dict["<pad>"] = 0
+ vocab_dict["<s>"] = 1
+ with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
+ json.dump(vocab_dict, vocab_handle)
+ tokenizer = Wav2Vec2CTCTokenizer(
+ vocab_path,
+ unk_token=target_dict.unk_word,
+ pad_token=target_dict.pad_word,
+ bos_token=target_dict.bos_word,
+ eos_token=target_dict.eos_word,
+ word_delimiter_token="|",
+ do_lower_case=False,
+ )
+ return_attention_mask = True if config.feat_extract_norm == "layer" else False
+ feature_extractor = Wav2Vec2FeatureExtractor(
+ feature_size=1,
+ sampling_rate=16000,
+ padding_value=0,
+ do_normalize=True,
+ return_attention_mask=return_attention_mask,
+ )
+ processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+ processor.save_pretrained(pytorch_dump_folder_path)
+
+ hf_wav2vec = Wav2Vec2ConformerForCTC(config)
+ else:
+ hf_wav2vec = Wav2Vec2ConformerForPreTraining(config)
+
+ if is_finetuned:
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
+ [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
+ )
+ else:
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
+
+ model = model[0].eval()
+
+ recursively_load_weights(model, hf_wav2vec, not is_finetuned)
+
+ hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
+ parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
+ parser.add_argument(
+ "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
+ )
+ args = parser.parse_args()
+ convert_wav2vec2_conformer_checkpoint(
+ args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
+ )
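A hedged sketch of calling the conversion entry point above directly from Python instead of via the CLI; the paths below are placeholders, and a working `fairseq` installation plus the module path introduced by this diff are assumed.

```python
# Hypothetical paths, for illustration only.
from transformers.models.wav2vec2_conformer.convert_wav2vec2_conformer_original_pytorch_checkpoint_to_pytorch import (
    convert_wav2vec2_conformer_checkpoint,
)

convert_wav2vec2_conformer_checkpoint(
    checkpoint_path="/path/to/checkpoint_best.pt",    # fairseq checkpoint (placeholder)
    pytorch_dump_folder_path="/path/to/output_dir",   # must already exist as a directory
    config_path=None,                                 # optional: start from an existing HF config.json
    dict_path="/path/to/dict.ltr.txt",                # fairseq dictionary, needed for fine-tuned models
    is_finetuned=True,                                # equivalent to omitting --not_finetuned on the CLI
)
```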
diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
new file mode 100644
index 000000000000..4c4962b155c3
--- /dev/null
+++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -0,0 +1,2127 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Wav2Vec2-Conformer model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...deepspeed import is_deepspeed_zero3_enabled
+from ...modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_int_div
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+# General docstring
+_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
+_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 64.21
+
+# Audio class docstring
+_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
+_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/wav2vec2-conformer-seq-class"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
+_SEQ_CLASS_EXPECTED_LOSS = 0.68
+
+# Frame class docstring
+_FRAME_CLASS_CHECKPOINT = "hf-internal-testing/wav2vec2-conformer-frame-class"
+_FRAME_EXPECTED_OUTPUT = [1, 0]
+
+# Speaker Verification docstring
+_XVECTOR_CHECKPOINT = "hf-internal-testing/wav2vec2-conformer-xvector"
+_XVECTOR_EXPECTED_OUTPUT = 1.0
+
+
+WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "facebook/wav2vec2-conformer-large-rel-pos",
+ # See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
+]
+
+
+@dataclass
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
+ """
+ Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
+
+ Args:
+ loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
+ paper](https://arxiv.org/pdf/2006.11477.pdf).
+ projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
+ projected quantized states.
+ projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
+ target vectors for contrastive loss.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
+ diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ projected_states: torch.FloatTensor = None
+ projected_quantized_states: torch.FloatTensor = None
+ codevector_perplexity: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ contrastive_loss: Optional[torch.FloatTensor] = None
+ diversity_loss: Optional[torch.FloatTensor] = None
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+ f" and `sequence_length`: {sequence_length}`"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.sum(-1).detach().tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+ # this case can only happen if `input_length` is strictly smaller than
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
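A small usage sketch of the SpecAugment helper above, using the config's default fine-tuning values `mask_time_prob=0.05`, `mask_time_length=10`, `mask_time_min_masks=2` (import path taken from this diff; a NumPy version where `np.bool`, used by the helper, is still available is assumed).

```python
import numpy as np

from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices

np.random.seed(0)
mask = _compute_mask_indices(
    shape=(2, 200),   # (batch_size, sequence_length) of the encoder hidden states
    mask_prob=0.05,
    mask_length=10,
    min_masks=2,
)
print(mask.shape)          # (2, 200), boolean SpecAugment mask
print(mask.sum(axis=-1))   # 2 spans of length 10 per row -> between 11 and 20 True entries each
```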
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
+def _sample_negative_indices(
+ features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
+):
+ """
+ Sample `num_negatives` vectors from feature vectors.
+ """
+ batch_size, sequence_length = features_shape
+
+ # generate indices of the positive vectors themselves, repeat them `num_negatives` times
+ sequence_length_range = np.arange(sequence_length)
+
+ # get `num_negatives` random vector indices from the same utterance
+ sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
+
+ mask_time_indices = (
+ mask_time_indices.astype(np.bool) if mask_time_indices is not None else np.ones(features_shape, dtype=np.bool)
+ )
+
+ for batch_idx in range(batch_size):
+ high = mask_time_indices[batch_idx].sum() - 1
+ mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
+
+ feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
+ sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
+ # avoid sampling the same positive vector, but keep the distribution uniform
+ sampled_indices[sampled_indices >= feature_indices] += 1
+
+ # remap to actual indices
+ sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
+
+ # correct for batch size
+ sampled_negative_indices[batch_idx] += batch_idx * sequence_length
+
+ return sampled_negative_indices
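A small sketch of the negative-sampling helper above (import path from this diff, same NumPy assumption as before): for every time step it draws `num_negatives` distractor positions from the same utterance, expressed as flat indices into the `(batch_size * sequence_length)` axis.

```python
import numpy as np

from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _sample_negative_indices

np.random.seed(0)
batch_size, sequence_length, num_negatives = 2, 6, 3
negatives = _sample_negative_indices((batch_size, sequence_length), num_negatives)
print(negatives.shape)                               # (2, 6, 3), int32
print(bool(negatives[1].min() >= sequence_length))   # True: row 1 is offset by sequence_length, so the
# indices can address a flattened (batch_size * sequence_length, dim) tensor of quantized features.
```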
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ hidden_states = hidden_states.transpose(-2, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(-2, -1)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.num_conv_pos_embeddings,
+ padding=config.num_conv_pos_embeddings // 2,
+ groups=config.num_conv_pos_embedding_groups,
+ )
+
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+ deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
+ deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
+ else:
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+
+ self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.transpose(1, 2)
+
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.padding(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
+ """Rotary positional embedding
+ Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ dim = config.hidden_size // config.num_attention_heads
+ base = config.rotary_embedding_base
+
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self.cached_sequence_length = None
+ self.cached_rotary_positional_embedding = None
+
+ def forward(self, hidden_states):
+ sequence_length = hidden_states.shape[1]
+
+ if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
+ return self.cached_rotary_positional_embedding
+
+ self.cached_sequence_length = sequence_length
+ time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
+ freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
+ embeddings = torch.cat((freqs, freqs), dim=-1)
+
+ cos_embeddings = embeddings.cos()[:, None, None, :]
+ sin_embeddings = embeddings.sin()[:, None, None, :]
+ self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
+ return self.cached_rotary_positional_embedding
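A shape sketch for the rotary table above, assuming the default `hidden_size=768` and `num_attention_heads=12` (head size 64, so `inv_freq` has 32 entries).

```python
import torch

from transformers import Wav2Vec2ConformerConfig
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
    Wav2Vec2ConformerRotaryPositionalEmbedding,
)

config = Wav2Vec2ConformerConfig(position_embeddings_type="rotary")
rotary = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
embeddings = rotary(torch.zeros(1, 50, config.hidden_size))
print(embeddings.shape)  # torch.Size([2, 50, 1, 1, 64]): stacked cos/sin tables, cached until the length changes
```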
+
+
+class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
+ """Relative positional encoding module."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.max_len = config.max_source_positions
+ self.d_model = config.hidden_size
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+
+ def extend_pe(self, x):
+ # Reset the positional encodings
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` is the position of query vector and `j` is the
+ # position of key vector. We use positive relative positions when keys
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
+ pe_positive = torch.zeros(x.size(1), self.d_model)
+ pe_negative = torch.zeros(x.size(1), self.d_model)
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
+ )
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+ # Reverse the order of positive indices and concat both positive and
+ # negative indices. This is used to support the shifting trick
+ # as in https://arxiv.org/abs/1901.02860
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+ pe_negative = pe_negative[1:].unsqueeze(0)
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self, hidden_states: torch.Tensor):
+ self.extend_pe(hidden_states)
+ start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
+ end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
+ relative_position_embeddings = self.pe[:, start_idx:end_idx]
+
+ return relative_position_embeddings
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerSamePadLayer(nn.Module):
+ def __init__(self, num_conv_pos_embeddings):
+ super().__init__()
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
+
+ def forward(self, hidden_states):
+ if self.num_pad_remove > 0:
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeatureEncoder(nn.Module):
+ """Construct the features from raw audio waveform"""
+
+ def __init__(self, config):
+ super().__init__()
+
+ if config.feat_extract_norm == "group":
+ conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
+ Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
+ for i in range(config.num_feat_extract_layers - 1)
+ ]
+ elif config.feat_extract_norm == "layer":
+ conv_layers = [
+ Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
+ ]
+ else:
+ raise ValueError(
+ f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
+ )
+ self.conv_layers = nn.ModuleList(conv_layers)
+ self.gradient_checkpointing = False
+ self._requires_grad = True
+
+ def _freeze_parameters(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self._requires_grad = False
+
+ def forward(self, input_values):
+ hidden_states = input_values[:, None]
+
+ # make sure hidden_states require grad for gradient_checkpointing
+ if self._requires_grad and self.training:
+ hidden_states.requires_grad = True
+
+ for conv_layer in self.conv_layers:
+ if self._requires_grad and self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(conv_layer),
+ hidden_states,
+ )
+ else:
+ hidden_states = conv_layer(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeatureProjection(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+ self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+ def forward(self, hidden_states):
+ # non-projected hidden states are needed for quantization
+ norm_hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.projection(norm_hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states, norm_hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeedForward(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.intermediate_dropout = nn.Dropout(config.activation_dropout)
+
+ self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.output_dropout = nn.Dropout(config.hidden_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.intermediate_dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ hidden_states = self.intermediate_dropout(hidden_states)
+
+ hidden_states = self.output_dense(hidden_states)
+ hidden_states = self.output_dropout(hidden_states)
+ return hidden_states
+
+
+class Wav2Vec2ConformerConvolutionModule(nn.Module):
+ """Convolution block used in the conformer block"""
+
+ def __init__(self, config):
+ super().__init__()
+ if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
+ raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
+ self.pointwise_conv1 = torch.nn.Conv1d(
+ config.hidden_size,
+ 2 * config.hidden_size,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.glu = torch.nn.GLU(dim=1)
+ self.depthwise_conv = torch.nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ config.conv_depthwise_kernel_size,
+ stride=1,
+ padding=(config.conv_depthwise_kernel_size - 1) // 2,
+ groups=config.hidden_size,
+ bias=False,
+ )
+ self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
+ self.activation = ACT2FN[config.hidden_act]
+ self.pointwise_conv2 = torch.nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.layer_norm(hidden_states)
+ # exchange the temporal dimension and the feature dimension
+ hidden_states = hidden_states.transpose(1, 2)
+
+ # GLU mechanism
+ # => (batch, 2*channel, dim)
+ hidden_states = self.pointwise_conv1(hidden_states)
+ # => (batch, channel, dim)
+ hidden_states = self.glu(hidden_states)
+
+ # 1D Depthwise Conv
+ hidden_states = self.depthwise_conv(hidden_states)
+ hidden_states = self.batch_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = self.pointwise_conv2(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
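A shape walk through the convolution module above, as a sketch under the default config: the first pointwise conv doubles the channels, the GLU halves them again, and the depthwise conv uses 'SAME' padding, so the time axis is preserved.

```python
import torch

from transformers import Wav2Vec2ConformerConfig
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import Wav2Vec2ConformerConvolutionModule

config = Wav2Vec2ConformerConfig()
conv_module = Wav2Vec2ConformerConvolutionModule(config).eval()  # eval: BatchNorm1d uses running stats

hidden_states = torch.randn(1, 40, config.hidden_size)  # (batch, time, hidden)
with torch.no_grad():
    output = conv_module(hidden_states)
print(output.shape)  # torch.Size([1, 40, 768]): same shape as the input
```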
+
+
+class Wav2Vec2ConformerSelfAttention(nn.Module):
+ """Construct an Wav2Vec2ConformerSelfAttention object.
+ Can be enhanced with rotary or relative position embeddings.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.head_size = config.hidden_size // config.num_attention_heads
+ self.num_heads = config.num_attention_heads
+ self.position_embeddings_type = config.position_embeddings_type
+
+ self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
+
+ self.dropout = nn.Dropout(p=config.attention_dropout)
+
+ if self.position_embeddings_type == "relative":
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+ # these two learnable biases are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+ self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ relative_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # self-attention mechanism
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ # make sure query/key states can be != value states
+ query_key_states = hidden_states
+ value_states = hidden_states
+
+ if self.position_embeddings_type == "rotary":
+ if relative_position_embeddings is None:
+ raise ValueError(
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
+ )
+ query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
+
+ # project query_key_states and value_states
+ query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+ key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+ value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
+
+ # => (batch, head, time1, d_k)
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ if self.position_embeddings_type == "relative":
+ if relative_position_embeddings is None:
+ raise ValueError(
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type =="
+ " 'relative'"
+ )
+ # apply relative_position_embeddings to qk scores
+ # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860
+ scores = self._apply_relative_embeddings(
+ query=query, key=key, relative_position_embeddings=relative_position_embeddings
+ )
+ else:
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)
+
+ # apply attention_mask if necessary
+ if attention_mask is not None:
+ scores = scores + attention_mask
+
+ # => (batch, head, time1, time2)
+ probs = torch.softmax(scores, dim=-1)
+ probs = self.dropout(probs)
+
+ # => (batch, head, time1, d_k)
+ hidden_states = torch.matmul(probs, value)
+
+ # => (batch, time1, hidden_size)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
+ hidden_states = self.linear_out(hidden_states)
+
+ return hidden_states, probs
+
+ def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
+
+ cos = relative_position_embeddings[0, :sequence_length, ...]
+ sin = relative_position_embeddings[1, :sequence_length, ...]
+
+ # rotate hidden_states with rotary embeddings
+ hidden_states = hidden_states.transpose(0, 1)
+ rotated_states_begin = hidden_states[..., : self.head_size // 2]
+ rotated_states_end = hidden_states[..., self.head_size // 2 :]
+ rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
+ hidden_states = (hidden_states * cos) + (rotated_states * sin)
+ hidden_states = hidden_states.transpose(0, 1)
+
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
+
+ return hidden_states
+
+ def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
+ # 1. project positional embeddings
+ # => (batch, head, 2*time1-1, d_k)
+ proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
+ proj_relative_position_embeddings = proj_relative_position_embeddings.view(
+ relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
+ )
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
+
+ # 2. Add bias to query
+ # => (batch, head, time1, d_k)
+ query = query.transpose(1, 2)
+ q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+ # 3. attention score: first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # => (batch, head, time1, time2)
+ scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+ # 4. then compute matrix b and matrix d
+ # => (batch, head, time1, 2*time1-1)
+ scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
+
+ # 5. shift matrix b and matrix d
+ zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
+ scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
+ scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
+ scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
+ scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
+ scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
+
+ # 6. sum matrices
+ # => (batch, head, time1, time2)
+ scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
+
+ return scores
+
+
+class Wav2Vec2ConformerEncoderLayer(nn.Module):
+ """Conformer block based on https://arxiv.org/abs/2005.08100."""
+
+ def __init__(self, config):
+ super().__init__()
+ embed_dim = config.hidden_size
+ dropout = config.attention_dropout
+
+ # Feed-forward 1
+ self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
+ self.ffn1 = Wav2Vec2ConformerFeedForward(config)
+
+ # Self-Attention
+ self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
+ self.self_attn_dropout = torch.nn.Dropout(dropout)
+ self.self_attn = Wav2Vec2ConformerSelfAttention(config)
+
+ # Conformer Convolution
+ self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
+
+ # Feed-forward 2
+ self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
+ self.ffn2 = Wav2Vec2ConformerFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask: Optional[torch.Tensor] = None,
+ relative_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ):
+ hidden_states = hidden_states
+
+ # 1. Feed-Forward 1 layer
+ residual = hidden_states
+ hidden_states = self.ffn1_layer_norm(hidden_states)
+ hidden_states = self.ffn1(hidden_states)
+ hidden_states = hidden_states * 0.5 + residual
+ residual = hidden_states
+
+ # 2. Self-Attention layer
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.self_attn_dropout(hidden_states)
+ hidden_states = hidden_states + residual
+
+ # 3. Convolutional Layer
+ residual = hidden_states
+ hidden_states = self.conv_module(hidden_states)
+ hidden_states = residual + hidden_states
+
+ # 4. Feed-Forward 2 Layer
+ residual = hidden_states
+ hidden_states = self.ffn2_layer_norm(hidden_states)
+ hidden_states = self.ffn2(hidden_states)
+ hidden_states = hidden_states * 0.5 + residual
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ return hidden_states, attn_weights
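A minimal sketch of one Conformer block above (macaron structure: half-step FFN, self-attention, convolution module, second half-step FFN, final LayerNorm); `position_embeddings_type=None` is used here so no relative or rotary table has to be passed.

```python
import torch

from transformers import Wav2Vec2ConformerConfig
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import Wav2Vec2ConformerEncoderLayer

config = Wav2Vec2ConformerConfig(position_embeddings_type=None)
layer = Wav2Vec2ConformerEncoderLayer(config).eval()

hidden_states = torch.randn(2, 25, config.hidden_size)
with torch.no_grad():
    hidden_states, attn_weights = layer(hidden_states)
print(hidden_states.shape)  # torch.Size([2, 25, 768]); every sub-module is residual, so the shape is preserved
```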
+
+
+class Wav2Vec2ConformerEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ if config.position_embeddings_type == "relative":
+ self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
+ elif config.position_embeddings_type == "rotary":
+ self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
+ else:
+ self.embed_positions = None
+
+ self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if attention_mask is not None:
+ # make sure padded tokens output 0
+ hidden_states[~attention_mask] = 0.0
+
+ # extend attention_mask
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
+ attention_mask = attention_mask.expand(
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+ )
+
+ hidden_states = self.dropout(hidden_states)
+
+ if self.embed_positions is not None:
+ relative_position_embeddings = self.embed_positions(hidden_states)
+ else:
+ relative_position_embeddings = None
+
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = np.random.uniform(0, 1)
+
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
+ # under deepspeed zero3 all gpus must run in sync
+ if self.gradient_checkpointing and self.training:
+ # create gradient checkpointing function
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer),
+ hidden_states,
+ attention_mask,
+ relative_position_embeddings,
+ )
+ else:
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
+ """
+ Vector quantization using Gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
+ GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.num_groups = config.num_codevector_groups
+ self.num_vars = config.num_codevectors_per_group
+
+ if config.codevector_dim % self.num_groups != 0:
+ raise ValueError(
+ f"`config.codevector_dim {config.codevector_dim} must be divisible "
+ f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
+ )
+
+ # storage for codebook variables (codewords)
+ self.codevectors = nn.Parameter(
+ torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
+ )
+ self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
+
+ # can be decayed for training
+ self.temperature = 2
+
+ @staticmethod
+ def _compute_perplexity(probs, mask=None):
+ if mask is not None:
+ mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
+ probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
+ marginal_probs = probs.sum(dim=0) / mask.sum()
+ else:
+ marginal_probs = probs.mean(dim=0)
+
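+ # perplexity = exp(entropy) of the marginal codeword distribution; it is highest when all codewords are used equally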
+ perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
+ return perplexity
+
+ def forward(self, hidden_states, mask_time_indices=None):
+ batch_size, sequence_length, hidden_size = hidden_states.shape
+
+ # project to codevector dim
+ hidden_states = self.weight_proj(hidden_states)
+ hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
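+ # each row now holds the logits over the `num_vars` codewords of a single codebook group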
+
+ if self.training:
+ # sample code vector probs via gumbel softmax in a differentiable way
+ codevector_probs = nn.functional.gumbel_softmax(
+ hidden_states.float(), tau=self.temperature, hard=True
+ ).type_as(hidden_states)
+
+ # compute perplexity
+ codevector_soft_dist = torch.softmax(
+ hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
+ )
+ perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
+ else:
+ # take argmax in non-differentiable way
+ # compute hard codevector distribution (one hot)
+ codevector_idx = hidden_states.argmax(dim=-1)
+ codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
+ -1, codevector_idx.view(-1, 1), 1.0
+ )
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
+
+ perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
+
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
+ # use probs to retrieve codevectors
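+ # with one-hot probs this selects one codeword per group; the groups are then concatenated along the feature dim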
+ codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
+ codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
+ codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
+
+ return codevectors, perplexity
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerAdapter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ # feature dim might need to be down-projected
+ if config.output_hidden_size != config.hidden_size:
+ self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
+ self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
+ else:
+ self.proj = self.proj_layer_norm = None
+
+ self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
+ self.layerdrop = config.layerdrop
+
+ def forward(self, hidden_states):
+ # down project hidden_states if necessary
+ if self.proj is not None and self.proj_layer_norm is not None:
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.proj_layer_norm(hidden_states)
+
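+ # the Conv1d layers in the adapter expect channels-first input: (batch, channels, seq_len)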
+ hidden_states = hidden_states.transpose(1, 2)
+
+ for layer in self.layers:
+ layerdrop_prob = np.random.random()
+ if not self.training or (layerdrop_prob > self.layerdrop):
+ hidden_states = layer(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerAdapterLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.output_hidden_size,
+ 2 * config.output_hidden_size,
+ config.adapter_kernel_size,
+ stride=config.adapter_stride,
+ padding=1,
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
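+ # the conv doubles the channel dim; GLU gates and halves it back to output_hidden_size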
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
+
+ return hidden_states
+
+
+class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = Wav2Vec2ConformerConfig
+ base_model_prefix = "wav2vec2_conformer"
+ main_input_name = "input_values"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ # gumbel softmax requires special init
+ if isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
+ module.weight_proj.weight.data.normal_(mean=0.0, std=1)
+ module.weight_proj.bias.data.zero_()
+ nn.init.uniform_(module.codevectors)
+ elif isinstance(module, Wav2Vec2ConformerSelfAttention):
+ if hasattr(module, "pos_bias_u"):
+ nn.init.xavier_uniform_(module.pos_bias_u)
+ if hasattr(module, "pos_bias_v"):
+ nn.init.xavier_uniform_(module.pos_bias_v)
+ elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
+ nn.init.normal_(
+ module.conv.weight,
+ mean=0,
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+ )
+ nn.init.constant_(module.conv.bias, 0)
+ elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+
+ def _get_feat_extract_output_lengths(
+ self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
+ ):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return torch_int_div(input_length - kernel_size, stride) + 1
+
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+ if add_adapter:
+ for _ in range(self.config.num_adapter_layers):
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(
+ self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
+ ):
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
+ # in inference mode.
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+ output_lengths = output_lengths.to(torch.long)
+
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ # these two operations make sure that all values before the output length indices are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
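+ # the flipped cumulative sum propagates that 1 back to position 0, marking every valid frame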
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
+ module.gradient_checkpointing = value
+
+
+WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
+ Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
+ Auli.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, etc.).
+
+ This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
+ regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
+
+ Parameters:
+ config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
+ into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
+ soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
+ and conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+ 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+
+
+ `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
+ True`. For all models whose processor has `config.return_attention_mask == False`, such as
+ [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
+ `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
+ such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
+ that these models also yield slightly different results depending on whether `input_values` is padded or
+ not.
+
+
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
+ def __init__(self, config: Wav2Vec2ConformerConfig):
+ super().__init__(config)
+ self.config = config
+ self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
+ self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
+
+ # model only needs masking vector if mask prob is > 0.0
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
+
+ self.encoder = Wav2Vec2ConformerEncoder(config)
+
+ self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.feature_extractor._freeze_parameters()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
+ def _mask_hidden_states(
+ self,
+ hidden_states: torch.FloatTensor,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ """
+ Masks extracted features along time axis and/or along feature axis according to
+ [SpecAugment](https://arxiv.org/abs/1904.08779).
+ """
+
+ # `config.apply_spec_augment` can set masking to False
+ if not getattr(self.config, "apply_spec_augment", True):
+ return hidden_states
+
+ # generate indices & apply SpecAugment along time axis
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ if mask_time_indices is not None:
+ # apply SpecAugment along time axis with given mask_time_indices
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+ elif self.config.mask_time_prob > 0 and self.training:
+ mask_time_indices = _compute_mask_indices(
+ (batch_size, sequence_length),
+ mask_prob=self.config.mask_time_prob,
+ mask_length=self.config.mask_time_length,
+ attention_mask=attention_mask,
+ min_masks=self.config.mask_time_min_masks,
+ )
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+
+ if self.config.mask_feature_prob > 0 and self.training:
+ # generate indices & apply SpecAugment along feature axis
+ mask_feature_indices = _compute_mask_indices(
+ (batch_size, hidden_size),
+ mask_prob=self.config.mask_feature_prob,
+ mask_length=self.config.mask_feature_length,
+ min_masks=self.config.mask_feature_min_masks,
+ )
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
+ hidden_states[mask_feature_indices] = 0
+
+ return hidden_states
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_PROCESSOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=Wav2Vec2BaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(
+ extract_features.shape[1], attention_mask, add_adapter=False
+ )
+
+ hidden_states, extract_features = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if self.adapter is not None:
+ hidden_states = self.adapter(hidden_states)
+
+ if not return_dict:
+ return (hidden_states, extract_features) + encoder_outputs[1:]
+
+ return Wav2Vec2BaseModelOutput(
+ last_hidden_state=hidden_states,
+ extract_features=extract_features,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
+)
+class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def __init__(self, config: Wav2Vec2ConformerConfig):
+ super().__init__(config)
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
+
+ self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # make sure that project_hid & project_q are initialized like normal linear layers
+ self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
+ self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
+ def set_gumbel_temperature(self, temperature: int):
+ """
+ Set the Gumbel softmax temperature to a given value. Only necessary for training
+ """
+ self.quantizer.temperature = temperature
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ @staticmethod
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
+ def compute_contrastive_logits(
+ target_features: torch.FloatTensor,
+ negative_features: torch.FloatTensor,
+ predicted_features: torch.FloatTensor,
+ temperature: int = 0.1,
+ ):
+ """
+ Compute logits for contrastive loss using cosine similarity as the distance measure between
+ `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
+ """
+ target_features = torch.cat([target_features, negative_features], dim=0)
+
+ logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
+ target_features
+ )
+
+ # apply temperature
+ logits = logits / temperature
+ return logits
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.BoolTensor] = None,
+ sampled_negative_indices: Optional[torch.BoolTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
+ r"""
+ mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices to mask extracted features for contrastive loss. When in training mode, the model learns to
+ predict masked extracted features in *config.proj_codevector_dim* space.
+ sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
+ Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
+ Required input for pre-training.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
+ >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices
+ >>> from datasets import load_dataset
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+ >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
+
+ >>> # compute masked indices
+ >>> batch_size, raw_sequence_length = input_values.shape
+ >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
+ >>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
+ >>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
+
+ >>> with torch.no_grad():
+ ... outputs = model(input_values, mask_time_indices=mask_time_indices)
+
+ >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
+ >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
+
+ >>> # show that cosine similarity is much higher than random
+ >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
+ tensor(True)
+
+ >>> # for contrastive loss training model should be put into train mode
+ >>> model = model.train()
+ >>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
+ ```"""
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if mask_time_indices is not None:
+ mask_time_indices = mask_time_indices.to(torch.bool)
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ mask_time_indices=mask_time_indices,
+ return_dict=return_dict,
+ )
+
+ # 1. project all transformed features (including masked) to final vq dim
+ transformer_features = self.project_hid(outputs[0])
+
+ # 2. quantize all (unmasked) extracted features and project to final vq dim
+ extract_features = self.dropout_features(outputs[1])
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(
+ extract_features.shape[1], attention_mask, add_adapter=False
+ )
+
+ quantized_features, codevector_perplexity = self.quantizer(
+ extract_features, mask_time_indices=mask_time_indices
+ )
+ quantized_features = self.project_q(quantized_features)
+
+ loss = contrastive_loss = diversity_loss = None
+ if sampled_negative_indices is not None:
+ batch_size, sequence_length, hidden_size = quantized_features.shape
+
+ # for training, we sample negatives
+ # 3. sample K negatives (distractors) quantized states for contrastive loss
+ # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
+ # sample negative quantized vectors BTC => (BxT)C
+ negative_quantized_features = quantized_features.view(-1, hidden_size)[
+ sampled_negative_indices.long().view(-1)
+ ]
+ negative_quantized_features = negative_quantized_features.view(
+ batch_size, sequence_length, -1, hidden_size
+ ).permute(2, 0, 1, 3)
+
+ # 4. compute logits, corresponding to `logits = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
+ # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
+ logits = self.compute_contrastive_logits(
+ quantized_features[None, :],
+ negative_quantized_features,
+ transformer_features,
+ self.config.contrastive_logits_temperature,
+ )
+
+ # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
+ # its cosine similarity will be masked
+ neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
+
+ if neg_is_pos.any():
+ logits[1:][neg_is_pos] = float("-inf")
+
+ # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logits) =
+ # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
+ logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
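+ # the positive target sits at index 0 of each logits row; unmasked time steps get label -100 and are ignored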
+ target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
+
+ contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
+ # 7. compute diversity loss: \mathbf{L}_d
+ num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
+ diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
+
+ # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
+ loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
+
+ if not return_dict:
+ if loss is not None:
+ return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+ return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+
+ return Wav2Vec2ConformerForPreTrainingOutput(
+ loss=loss,
+ projected_states=transformer_features,
+ projected_quantized_states=quantized_features,
+ codevector_perplexity=codevector_perplexity,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ contrastive_loss=contrastive_loss,
+ diversity_loss=diversity_loss,
+ )
+
+
+@add_start_docstrings(
+ """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ self.dropout = nn.Dropout(config.final_dropout)
+
+ if config.vocab_size is None:
+ raise ValueError(
+ f"You are trying to instantiate {self.__class__} with a configuration that does not define the"
+ " vocabulary size of the language model head. Please instantiate the model as follows:"
+ " `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of"
+ " your model's configuration."
+ )
+ output_hidden_size = (
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+ )
+ self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_PROCESSOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CTC_EXPECTED_OUTPUT,
+ expected_loss=_CTC_EXPECTED_LOSS,
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, CausalLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller than or equal to
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+ config.vocab_size - 1]`.
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.dropout(hidden_states)
+
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+
+ if labels.max() >= self.config.vocab_size:
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+ # retrieve loss input_lengths from attention_mask
+ attention_mask = (
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+ )
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
+ # assuming that padded tokens are filled with -100
+ # when not being attended to
+ labels_mask = labels >= 0
+ target_lengths = labels_mask.sum(-1)
+ flattened_targets = labels.masked_select(labels_mask)
+
+ # ctc_loss doesn't support fp16
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
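+ # run the CTC loss with cuDNN disabled, since the cuDNN CTC kernel only supports a restricted set of inputs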
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = nn.functional.ctc_loss(
+ log_probs,
+ flattened_targets,
+ input_lengths,
+ target_lengths,
+ blank=self.config.pad_token_id,
+ reduction=self.config.ctc_loss_reduction,
+ zero_infinity=self.config.ctc_zero_infinity,
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
+ tasks like SUPERB Keyword Spotting.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def __init__(self, config):
+ super().__init__(config)
+
+ if hasattr(config, "add_adapter") and config.add_adapter:
+ raise ValueError(
+ "Sequence classification does not support the use of Wav2Vec2Conformer adapters"
+ " (config.add_adapter=True)"
+ )
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+ self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_conformer.parameters():
+ param.requires_grad = False
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_SEQ_CLASS_CHECKPOINT,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+ if attention_mask is None:
+ pooled_output = hidden_states.mean(dim=1)
+ else:
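+ # zero out padded frames and average only over the valid length of each example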
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+ hidden_states[~padding_mask] = 0.0
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+ def __init__(self, config):
+ super().__init__(config)
+
+ if hasattr(config, "add_adapter") and config.add_adapter:
+ raise ValueError(
+ "Audio frame classification does not support the use of Wav2Vec2Conformer adapters"
+ " (config.add_adapter=True)"
+ )
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ self.num_labels = config.num_labels
+
+ self.init_weights()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_conformer.parameters():
+ param.requires_grad = False
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_FRAME_CLASS_CHECKPOINT,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_FRAME_EXPECTED_OUTPUT,
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
+class AMSoftmaxLoss(nn.Module):
+ def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
+ super(AMSoftmaxLoss, self).__init__()
+ self.scale = scale
+ self.margin = margin
+ self.num_labels = num_labels
+ self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
+ self.loss = nn.CrossEntropyLoss()
+
+ def forward(self, hidden_states, labels):
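+ # additive-margin softmax: cosine similarity between normalized embeddings and class weights,
+ # with the margin subtracted from the target-class similarity before scaling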
+ labels = labels.flatten()
+ weight = nn.functional.normalize(self.weight, dim=0)
+ hidden_states = nn.functional.normalize(hidden_states, dim=1)
+ cos_theta = torch.mm(hidden_states, weight)
+ psi = cos_theta - self.margin
+
+ onehot = nn.functional.one_hot(labels, self.num_labels)
+ logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
+ loss = self.loss(logits, labels)
+
+ return loss
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
+class TDNNLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
+ self.out_conv_dim = config.tdnn_dim[layer_id]
+ self.kernel_size = config.tdnn_kernel[layer_id]
+ self.dilation = config.tdnn_dilation[layer_id]
+
+ self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
+ self.activation = nn.ReLU()
+
+ def forward(self, hidden_states):
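+ # implement the dilated 1D convolution (time-delay layer) as unfold + a linear layer over stacked frames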
+ hidden_states = hidden_states.unsqueeze(1)
+ hidden_states = nn.functional.unfold(
+ hidden_states,
+ (self.kernel_size, self.in_conv_dim),
+ stride=(1, self.in_conv_dim),
+ dilation=(self.dilation, 1),
+ )
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.kernel(hidden_states)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
+
+ tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
+ self.tdnn = nn.ModuleList(tdnn_layers)
+
+ self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
+ self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
+
+ self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
+
+ self.init_weights()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_conformer.parameters():
+ param.requires_grad = False
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
+ def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+ """
+ Computes the output length of the TDNN layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return (input_length - kernel_size) // stride + 1
+
+ for kernel_size in self.config.tdnn_kernel:
+ input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
+
+ return input_lengths
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_XVECTOR_CHECKPOINT,
+ output_type=XVectorOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_XVECTOR_EXPECTED_OUTPUT,
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, XVectorOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+
+ for tdnn_layer in self.tdnn:
+ hidden_states = tdnn_layer(hidden_states)
+
+ # Statistic Pooling
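+ # concatenate the per-utterance mean and standard deviation of the TDNN features,
+ # ignoring padded frames when an attention mask is available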
+ if attention_mask is None:
+ mean_features = hidden_states.mean(dim=1)
+ std_features = hidden_states.std(dim=1)
+ else:
+ feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
+ tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
+ mean_features = []
+ std_features = []
+ for i, length in enumerate(tdnn_output_lengths):
+ mean_features.append(hidden_states[i, :length].mean(dim=0))
+ std_features.append(hidden_states[i, :length].std(dim=0))
+ mean_features = torch.stack(mean_features)
+ std_features = torch.stack(std_features)
+ statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
+
+ output_embeddings = self.feature_extractor(statistic_pooling)
+ logits = self.classifier(output_embeddings)
+
+ loss = None
+ if labels is not None:
+ loss = self.objective(logits, labels)
+
+ if not return_dict:
+ output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return XVectorOutput(
+ loss=loss,
+ logits=logits,
+ embeddings=output_embeddings,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/src/transformers/models/wav2vec2_phoneme/__init__.py b/src/transformers/models/wav2vec2_phoneme/__init__.py
index 4d6ea18a330e..84dc9942d515 100644
--- a/src/transformers/models/wav2vec2_phoneme/__init__.py
+++ b/src/transformers/models/wav2vec2_phoneme/__init__.py
@@ -20,11 +20,7 @@
from ...utils import _LazyModule
-# fmt: off
-_import_structure = {
- "tokenization_wav2vec2_phoneme": ["Wav2Vec2PhonemeCTCTokenizer"]
-}
-# fmt: on
+_import_structure = {"tokenization_wav2vec2_phoneme": ["Wav2Vec2PhonemeCTCTokenizer"]}
if TYPE_CHECKING:
diff --git a/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py b/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py
index 6bd355645e5a..c983c4be8264 100644
--- a/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py
+++ b/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py
@@ -55,10 +55,14 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/wav2vec2-lv-60-espeak-cv-ft": "https://huggingface.co/facebook/wav2vec2-lv-60-espeak-cv-ft/resolve/main/vocab.json",
+ "facebook/wav2vec2-lv-60-espeak-cv-ft": (
+ "https://huggingface.co/facebook/wav2vec2-lv-60-espeak-cv-ft/resolve/main/vocab.json"
+ ),
},
"tokenizer_config_file": {
- "facebook/wav2vec2-lv-60-espeak-cv-ft": "https://huggingface.co/facebook/wav2vec2-lv-60-espeak-cv-ft/resolve/main/tokenizer_config.json",
+ "facebook/wav2vec2-lv-60-espeak-cv-ft": (
+ "https://huggingface.co/facebook/wav2vec2-lv-60-espeak-cv-ft/resolve/main/tokenizer_config.json"
+ ),
},
}
@@ -369,7 +373,7 @@ def convert_tokens_to_string(
if len(char_offsets) != len(processed_chars):
raise ValueError(
f"`char_offsets`: {char_offsets} and `processed_tokens`: {processed_chars}"
- f" have to be of the same length, but are: `len(offsets)`: "
+ " have to be of the same length, but are: `len(offsets)`: "
f"{len(char_offsets)} and `len(processed_tokens)`: {len(processed_chars)}"
)
@@ -564,7 +568,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return (vocab_file,)
@@ -600,7 +604,7 @@ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_to
tokens_to_add = []
for token in new_tokens:
if not isinstance(token, str):
- raise ValueError(f"Token {token} has to be of type string, but is " f"of type {type(token)}.")
+ raise ValueError(f"Token {token} has to be of type string, but is of type {type(token)}.")
assert isinstance(token, str)
if (
token != self.unk_token
diff --git a/src/transformers/models/wav2vec2_with_lm/__init__.py b/src/transformers/models/wav2vec2_with_lm/__init__.py
index 8730f3508e30..174946ae1018 100644
--- a/src/transformers/models/wav2vec2_with_lm/__init__.py
+++ b/src/transformers/models/wav2vec2_with_lm/__init__.py
@@ -20,11 +20,7 @@
from ...utils import _LazyModule
-# fmt: off
-_import_structure = {
- "processing_wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"]
-}
-# fmt: on
+_import_structure = {"processing_wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"]}
if TYPE_CHECKING:
diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
index 4e7da075261b..f09b5eb922ab 100644
--- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
+++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
@@ -16,6 +16,7 @@
Speech processor class for Wav2Vec2
"""
import os
+import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from multiprocessing import get_context
@@ -99,6 +100,7 @@ def __init__(
self.decoder = decoder
self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
def save_pretrained(self, save_directory):
super().save_pretrained(save_directory)
@@ -214,7 +216,35 @@ def __call__(self, *args, **kwargs):
Wav2Vec2CTCTokenizer's [`~Wav2Vec2CTCTokenizer.__call__`]. Please refer to the docstring of the above two
methods for more information.
"""
- return self.current_processor(*args, **kwargs)
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor(*args, **kwargs)
+
+ if "raw_speech" in kwargs:
+ warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
+ audio = kwargs.pop("raw_speech")
+ else:
+ audio = kwargs.pop("audio", None)
+ text = kwargs.pop("text", None)
+ if len(args) > 0:
+ audio = args[0]
+ args = args[1:]
+
+ if audio is None and text is None:
+ raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+ if audio is not None:
+ inputs = self.feature_extractor(audio, *args, **kwargs)
+ if text is not None:
+ encodings = self.tokenizer(text, **kwargs)
+
+ if text is None:
+ return inputs
+ elif audio is None:
+ return encodings
+ else:
+ inputs["labels"] = encodings["input_ids"]
+ return inputs
def pad(self, *args, **kwargs):
"""
@@ -224,7 +254,28 @@ def pad(self, *args, **kwargs):
Wav2Vec2CTCTokenizer's [`~Wav2Vec2CTCTokenizer.pad`]. Please refer to the docstring of the above two methods
for more information.
"""
- return self.current_processor.pad(*args, **kwargs)
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor.pad(*args, **kwargs)
+
+ input_features = kwargs.pop("input_features", None)
+ labels = kwargs.pop("labels", None)
+ if len(args) > 0:
+ input_features = args[0]
+ args = args[1:]
+
+ if input_features is not None:
+ input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
+ if labels is not None:
+ labels = self.tokenizer.pad(labels, **kwargs)
+
+ if labels is None:
+ return input_features
+ elif input_features is None:
+ return labels
+ else:
+ input_features["labels"] = labels["input_ids"]
+ return input_features
def batch_decode(
self,
@@ -486,9 +537,16 @@ def decode(
@contextmanager
def as_target_processor(self):
"""
- Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
+ Temporarily sets the processor for processing the target. Useful for encoding the labels when fine-tuning
Wav2Vec2.
"""
+ warnings.warn(
+ "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+ "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+ "your audio inputs, or in a separate call."
+ )
+ self._in_target_context_manager = True
self.current_processor = self.tokenizer
yield
self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
diff --git a/src/transformers/models/wavlm/__init__.py b/src/transformers/models/wavlm/__init__.py
index 576bbaf83cdf..9cd64b25dafa 100644
--- a/src/transformers/models/wavlm/__init__.py
+++ b/src/transformers/models/wavlm/__init__.py
@@ -17,14 +17,17 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
-_import_structure = {
- "configuration_wavlm": ["WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "WavLMConfig"],
-}
+_import_structure = {"configuration_wavlm": ["WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "WavLMConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_wavlm"] = [
"WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"WavLMForAudioFrameClassification",
@@ -38,7 +41,12 @@
if TYPE_CHECKING:
from .configuration_wavlm import WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP, WavLMConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_wavlm import (
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST,
WavLMForAudioFrameClassification,
diff --git a/src/transformers/models/wavlm/configuration_wavlm.py b/src/transformers/models/wavlm/configuration_wavlm.py
index d7f0b7047030..7c908d3d7300 100644
--- a/src/transformers/models/wavlm/configuration_wavlm.py
+++ b/src/transformers/models/wavlm/configuration_wavlm.py
@@ -77,15 +77,15 @@ class WavLMConfig(PretrainedConfig):
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
The dropout probabilitiy for quantized feature encoder states.
- conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
- conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
- of *conv_stride* defines the number of convolutional layers and has to match the the length of *conv_dim*.
- conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
- length of *conv_kernel* defines the number of convolutional layers and has to match the the length of
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
*conv_dim*.
conv_bias (`bool`, *optional*, defaults to `False`):
Whether the 1D convolutional layers have a bias.
@@ -146,13 +146,13 @@ class WavLMConfig(PretrainedConfig):
instance of [`WavLMForSequenceClassification`].
classifier_proj_size (`int`, *optional*, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
- tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+ tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
- tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+ tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
*XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
- tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+ tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
*XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
xvector_output_dim (`int`, *optional*, defaults to 512):
@@ -290,10 +290,10 @@ def __init__(
or (len(self.conv_dim) != self.num_feat_extract_layers)
):
raise ValueError(
- "Configuration for convolutional layers is incorrect. "
- "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
- f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
- f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
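
For context on the reworded error message above, a runnable sketch of the same length check. ConvFeatureConfig is a made-up stand-in for WavLMConfig and only models the three tuple arguments being validated.

# Editor's sketch (not part of the patch): the conv-config consistency check in isolation.
from dataclasses import dataclass
from typing import Sequence


@dataclass
class ConvFeatureConfig:
    conv_dim: Sequence[int] = (512, 512, 512, 512, 512, 512, 512)
    conv_stride: Sequence[int] = (5, 2, 2, 2, 2, 2, 2)
    conv_kernel: Sequence[int] = (10, 3, 3, 3, 3, 3, 3)

    def __post_init__(self):
        if not (len(self.conv_dim) == len(self.conv_stride) == len(self.conv_kernel)):
            raise ValueError(
                "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
                " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is"
                f" `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
                f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
            )


ConvFeatureConfig()                        # consistent lengths: passes silently
try:
    ConvFeatureConfig(conv_stride=(5, 2))  # mismatched lengths: raises
except ValueError as err:
    print(err)
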
diff --git a/src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py
index 8523fa87eba8..91758cc95952 100644
--- a/src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py
@@ -74,9 +74,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -144,28 +145,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
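
The rewrapped asserts above lean on Python's compile-time concatenation of adjacent (f-)string literals, so the multi-line form still produces one continuous error message. A tiny sketch with invented names and shapes.

# Editor's sketch (not part of the patch): adjacent f-strings concatenate into a single message.
full_name, found, expected = "encoder.layers.0.conv.weight", (512, 10), (512, 9)

try:
    assert found == expected, (
        f"{full_name} has size {found}, but"
        f" {expected} was found."
    )
except AssertionError as err:
    print(err)  # encoder.layers.0.conv.weight has size (512, 10), but (512, 9) was found.
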
diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py
index c2eb193160ad..c792a368cb47 100755
--- a/src/transformers/models/wavlm/modeling_wavlm.py
+++ b/src/transformers/models/wavlm/modeling_wavlm.py
@@ -16,7 +16,6 @@
import math
import warnings
-from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
@@ -28,16 +27,17 @@
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
+from ...modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
-from ...utils import (
- ModelOutput,
- add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
-)
+from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_wavlm import WavLMConfig
@@ -80,67 +80,6 @@
]
-@dataclass
-class WavLMBaseModelOutput(ModelOutput):
- """
- Output type of [`WavLMBaseModelOutput`], with potential hidden states and attentions.
-
- Args:
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Sequence of hidden-states at the output of the last layer of the model.
- extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
- Sequence of extracted feature vectors of the last convolutional layer of the model.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- last_hidden_state: torch.FloatTensor = None
- extract_features: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
-@dataclass
-class XVectorOutput(ModelOutput):
- """
- Output type of [`Wav2Vec2ForXVector`].
-
- Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Classification loss.
- logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
- Classification hidden states before AMSoftmax.
- embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
- Utterance embeddings used for vector similarity-based retrieval.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- loss: Optional[torch.FloatTensor] = None
- logits: torch.FloatTensor = None
- embeddings: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
@@ -244,7 +183,7 @@ def compute_num_masked_span(input_length):
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
- # add offset to the starting indexes so that that indexes now create a span
+ # add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
@@ -620,12 +559,12 @@ def _relative_positions_bucket(self, relative_positions: torch.FloatTensor) -> t
relative_positions_if_large = torch.log(relative_positions.float() / max_exact)
relative_positions_if_large = relative_positions_if_large / math.log(self.max_distance / max_exact)
relative_positions_if_large = relative_positions_if_large * (num_buckets - max_exact)
- relative_postion_if_large = (max_exact + relative_positions_if_large).to(torch.long)
- relative_postion_if_large = torch.min(
- relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
+ relative_position_if_large = (max_exact + relative_positions_if_large).to(torch.long)
+ relative_position_if_large = torch.min(
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
- relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
+ relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
return relative_buckets
@@ -1184,7 +1123,7 @@ def _set_gradient_checkpointing(self, module, value=False):
"The bare WavLM Model transformer outputting raw hidden-states without any specific head on top.",
WAVLM_START_DOCSTRING,
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM, WavLMBaseModelOutput->Wav2Vec2BaseModelOutput
class WavLMModel(WavLMPreTrainedModel):
def __init__(self, config: WavLMConfig):
super().__init__(config)
@@ -1275,7 +1214,7 @@ def _mask_hidden_states(
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=WavLMBaseModelOutput,
+ output_type=Wav2Vec2BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -1288,7 +1227,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, WavLMBaseModelOutput]:
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1325,7 +1264,7 @@ def forward(
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
- return WavLMBaseModelOutput(
+ return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@@ -1606,6 +1545,7 @@ def __init__(self, config):
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ self.num_labels = config.num_labels
self.init_weights()
@@ -1649,6 +1589,7 @@ def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -1681,12 +1622,17 @@ def forward(
logits = self.classifier(hidden_states)
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return output
return TokenClassifierOutput(
- loss=None,
+ loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
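
The loss path added above expects one-hot frame labels: both tensors are flattened over batch and time, and argmax recovers class indices before CrossEntropyLoss. A minimal sketch with made-up shapes (dim=1 is used here where the patch uses the axis=1 alias).

# Editor's sketch (not part of the patch): frame-classification loss from one-hot labels.
import torch
from torch.nn import CrossEntropyLoss

batch_size, num_frames, num_labels = 2, 5, 3
logits = torch.randn(batch_size, num_frames, num_labels)              # per-frame scores
labels = torch.nn.functional.one_hot(
    torch.randint(num_labels, (batch_size, num_frames)), num_labels
).float()                                                              # one-hot targets

loss_fct = CrossEntropyLoss()
loss = loss_fct(
    logits.view(-1, num_labels),
    torch.argmax(labels.view(-1, num_labels), dim=1),
)
print(loss.item())
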
diff --git a/src/transformers/models/xglm/__init__.py b/src/transformers/models/xglm/__init__.py
index d5934dea6666..2ab60e4cb4bb 100644
--- a/src/transformers/models/xglm/__init__.py
+++ b/src/transformers/models/xglm/__init__.py
@@ -19,6 +19,7 @@
# rely on isort to merge the imports
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
@@ -27,17 +28,30 @@
)
-_import_structure = {
- "configuration_xglm": ["XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XGLMConfig"],
-}
+_import_structure = {"configuration_xglm": ["XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XGLMConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_xglm"] = ["XGLMTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_xglm_fast"] = ["XGLMTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_xglm"] = [
"XGLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"XGLMForCausalLM",
@@ -46,7 +60,12 @@
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_xglm"] = [
"FlaxXGLMForCausalLM",
"FlaxXGLMModel",
@@ -57,16 +76,36 @@
if TYPE_CHECKING:
from .configuration_xglm import XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XGLMConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_xglm import XGLMTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_xglm_fast import XGLMTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_xglm import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMPreTrainedModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index f26c7fa81839..6717d8d8e152 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -90,11 +90,11 @@
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
- ``input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape
- `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you
- can choose to directly pass an embedded representation. This is useful if you want more control over how to
- convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size,
+ sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to
+ directly pass an embedded representation. This is useful if you want more control over how to convert
+ `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. If
`past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
@@ -120,7 +120,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -142,7 +142,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
@@ -173,9 +173,7 @@ def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Opt
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
- self.weights = nn.Parameter(emb_weights)
- self.weights.requires_grad = False
- self.weights.detach_()
+ self.register_buffer("weights", emb_weights)
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
@@ -330,7 +328,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -346,7 +345,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -367,7 +367,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -574,7 +575,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -709,7 +710,7 @@ def forward(
hidden_states = inputs_embeds + positions
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)
# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -722,7 +723,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -738,7 +740,8 @@ def forward(
if use_cache:
logger.warning(
- "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..."
+ "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache ="
+ " False`..."
)
use_cache = False
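
The switch from float("-inf") to torch.finfo(dtype).min above keeps the causal mask finite, which avoids inf - inf = nan arithmetic in fp16/bf16 while still driving masked positions to ~0 probability after softmax. A small sketch of the idea with a simplified, standalone make_causal_mask.

# Editor's sketch (not part of the patch): a finite-valued causal mask.
import torch


def make_causal_mask(tgt_len: int, dtype: torch.dtype) -> torch.Tensor:
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
    mask_cond = torch.arange(tgt_len)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
    return mask.to(dtype)


mask = make_causal_mask(4, torch.float16)
print(mask)                                 # upper triangle holds dtype-min, not -inf
print(torch.softmax(mask[0].float(), -1))   # masked positions still get ~0 probability
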
diff --git a/src/transformers/models/xlm/__init__.py b/src/transformers/models/xlm/__init__.py
index f0a42e244e7e..de9be348b94c 100644
--- a/src/transformers/models/xlm/__init__.py
+++ b/src/transformers/models/xlm/__init__.py
@@ -18,15 +18,20 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
- "configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig"],
+ "configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMOnnxConfig"],
"tokenization_xlm": ["XLMTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_xlm"] = [
"XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLMForMultipleChoice",
@@ -39,7 +44,12 @@
"XLMWithLMHeadModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_xlm"] = [
"TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFXLMForMultipleChoice",
@@ -54,10 +64,15 @@
if TYPE_CHECKING:
- from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
+ from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMOnnxConfig
from .tokenization_xlm import XLMTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_xlm import (
XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMForMultipleChoice,
@@ -70,7 +85,12 @@
XLMWithLMHeadModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_xlm import (
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMForMultipleChoice,
diff --git a/src/transformers/models/xlm/configuration_xlm.py b/src/transformers/models/xlm/configuration_xlm.py
index d6f70c6671cc..e14ad2ec6cae 100644
--- a/src/transformers/models/xlm/configuration_xlm.py
+++ b/src/transformers/models/xlm/configuration_xlm.py
@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" XLM configuration"""
+from collections import OrderedDict
+from typing import Mapping
from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
from ...utils import logging
@@ -228,3 +231,20 @@ def __init__(
self.n_words = kwargs["n_words"]
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
+
+
+# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig
+class XLMOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "multiple-choice":
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+ else:
+ dynamic_axis = {0: "batch", 1: "sequence"}
+ return OrderedDict(
+ [
+ ("input_ids", dynamic_axis),
+ ("attention_mask", dynamic_axis),
+ ("token_type_ids", dynamic_axis),
+ ]
+ )
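
The new XLMOnnxConfig maps each input name to its dynamic axes. The sketch below reproduces just that mapping as a plain function; the resulting dict-of-dicts has the same shape that torch.onnx.export accepts for its dynamic_axes argument.

# Editor's sketch (not part of the patch): the dynamic-axes mapping in isolation.
from collections import OrderedDict


def onnx_input_axes(task: str = "default") -> "OrderedDict[str, dict]":
    if task == "multiple-choice":
        dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
    else:
        dynamic_axis = {0: "batch", 1: "sequence"}
    return OrderedDict(
        [
            ("input_ids", dynamic_axis),
            ("attention_mask", dynamic_axis),
            ("token_type_ids", dynamic_axis),
        ]
    )


print(onnx_input_axes())
print(onnx_input_axes("multiple-choice"))
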
diff --git a/src/transformers/models/xlm/modeling_tf_xlm.py b/src/transformers/models/xlm/modeling_tf_xlm.py
index 24d32f798f3d..3b5f1c6e2650 100644
--- a/src/transformers/models/xlm/modeling_tf_xlm.py
+++ b/src/transformers/models/xlm/modeling_tf_xlm.py
@@ -92,8 +92,8 @@ def get_masks(slen, lengths, causal, padding_mask=None):
mask = padding_mask
else:
# assert lengths.max().item() <= slen
- alen = tf.range(slen)
- mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1))
+ alen = tf.range(slen, dtype=lengths.dtype)
+ mask = alen < tf.expand_dims(lengths, axis=1)
# attention mask is the same as mask, or triangular inferior attention (causal)
if causal:
@@ -797,6 +797,8 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFXLMMainLayer(config, name="transformer")
self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
+ # XLM does not have past caching features
+ self.supports_xla_generation = False
def get_lm_head(self):
return self.pred_layer
diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py
index ebb3c503475c..79a1e0292e99 100755
--- a/src/transformers/models/xlm/modeling_xlm.py
+++ b/src/transformers/models/xlm/modeling_xlm.py
@@ -181,7 +181,7 @@ def unshape(x):
q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
- scores.masked_fill_(mask, -float("inf")) # (bs, n_heads, qlen, klen)
+ scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen)
weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen)
weights = nn.functional.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen)
diff --git a/src/transformers/models/xlm/tokenization_xlm.py b/src/transformers/models/xlm/tokenization_xlm.py
index f6c94f11ae46..bd7b58eb053b 100644
--- a/src/transformers/models/xlm/tokenization_xlm.py
+++ b/src/transformers/models/xlm/tokenization_xlm.py
@@ -697,7 +697,8 @@ def ja_tokenize(self, text):
)
except (AttributeError, ImportError):
logger.error(
- "Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps"
+ "Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper"
+ " (https://github.com/chezou/Mykytea-python) with the following steps"
)
logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
logger.error("2. autoreconf -i")
@@ -801,7 +802,8 @@ def _tokenize(self, text, lang="en", bypass_tokenizer=False):
"""
if lang and self.lang2id and lang not in self.lang2id:
logger.error(
- "Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model."
+ "Supplied language code not found in lang2id mapping. Please check that your language is supported by"
+ " the loaded pretrained model."
)
if bypass_tokenizer:
text = text.split()
@@ -963,7 +965,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
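
The save_vocabulary change above makes the written vocab file deterministic (sorted keys, fixed indentation) and terminates it with a newline, which keeps re-saves diff-friendly. A toy sketch with an invented three-entry encoder.

# Editor's sketch (not part of the patch): deterministic vocab serialization.
import json

encoder = {"hello</w>": 1, "<unk>": 0, "world</w>": 2}

# same call as in the patch, just printed instead of written via open(vocab_file, "w")
serialized = json.dumps(encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
print(serialized, end="")
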
diff --git a/src/transformers/models/xlm_prophetnet/__init__.py b/src/transformers/models/xlm_prophetnet/__init__.py
index fe69b5060765..8fbec3d400ed 100644
--- a/src/transformers/models/xlm_prophetnet/__init__.py
+++ b/src/transformers/models/xlm_prophetnet/__init__.py
@@ -17,20 +17,27 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, is_torch_available
_import_structure = {
- "configuration_xlm_prophetnet": [
- "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP",
- "XLMProphetNetConfig",
- ],
+ "configuration_xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_xlm_prophetnet"] = ["XLMProphetNetTokenizer"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_xlm_prophetnet"] = [
"XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLMProphetNetDecoder",
@@ -44,10 +51,20 @@
if TYPE_CHECKING:
from .configuration_xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_xlm_prophetnet import XLMProphetNetTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_xlm_prophetnet import (
XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMProphetNetDecoder,
diff --git a/src/transformers/models/xlm_prophetnet/configuration_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/configuration_xlm_prophetnet.py
index 2c3d21bd283c..3025ed29f643 100644
--- a/src/transformers/models/xlm_prophetnet/configuration_xlm_prophetnet.py
+++ b/src/transformers/models/xlm_prophetnet/configuration_xlm_prophetnet.py
@@ -22,7 +22,9 @@
logger = logging.get_logger(__name__)
XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "microsoft/xprophetnet-large-wiki100-cased": "https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased/resolve/main/config.json",
+ "microsoft/xprophetnet-large-wiki100-cased": (
+ "https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/xlm_prophetnet/tokenization_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/tokenization_xlm_prophetnet.py
index 48f68238f126..af8308287939 100644
--- a/src/transformers/models/xlm_prophetnet/tokenization_xlm_prophetnet.py
+++ b/src/transformers/models/xlm_prophetnet/tokenization_xlm_prophetnet.py
@@ -30,7 +30,9 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "microsoft/xprophetnet-large-wiki100-cased": "https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased/resolve/main/prophetnet.tokenizer",
+ "microsoft/xprophetnet-large-wiki100-cased": (
+ "https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased/resolve/main/prophetnet.tokenizer"
+ ),
}
}
@@ -159,8 +161,8 @@ def __init__(
import sentencepiece as spm
except ImportError:
logger.warning(
- "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece "
- "pip install sentencepiece"
+ "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece"
+ " pip install sentencepiece"
)
raise
@@ -198,8 +200,8 @@ def __setstate__(self, d):
import sentencepiece as spm
except ImportError:
logger.warning(
- "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece "
- "pip install sentencepiece"
+ "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece"
+ " pip install sentencepiece"
)
raise
diff --git a/src/transformers/models/xlm_roberta/__init__.py b/src/transformers/models/xlm_roberta/__init__.py
index a29a400c8b7d..60d26c131484 100644
--- a/src/transformers/models/xlm_roberta/__init__.py
+++ b/src/transformers/models/xlm_roberta/__init__.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
@@ -36,13 +37,28 @@
],
}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_xlm_roberta"] = ["XLMRobertaTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_xlm_roberta_fast"] = ["XLMRobertaTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_xlm_roberta"] = [
"XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLMRobertaForCausalLM",
@@ -54,7 +70,12 @@
"XLMRobertaModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_xlm_roberta"] = [
"TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFXLMRobertaForMaskedLM",
@@ -65,7 +86,12 @@
"TFXLMRobertaModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_xlm_roberta"] = [
"FlaxXLMRobertaForMaskedLM",
"FlaxXLMRobertaForMultipleChoice",
@@ -82,13 +108,28 @@
XLMRobertaOnnxConfig,
)
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_xlm_roberta import XLMRobertaTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_xlm_roberta import (
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMRobertaForCausalLM,
@@ -100,7 +141,12 @@
XLMRobertaModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_xlm_roberta import (
TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMRobertaForMaskedLM,
@@ -111,7 +157,12 @@
TFXLMRobertaModel,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_xlm_roberta import (
FlaxXLMRobertaForMaskedLM,
FlaxXLMRobertaForMultipleChoice,
diff --git a/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py b/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py
index c1469bfca4cf..194b38a8c181 100644
--- a/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py
@@ -27,10 +27,18 @@
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"xlm-roberta-base": "https://huggingface.co/xlm-roberta-base/resolve/main/config.json",
"xlm-roberta-large": "https://huggingface.co/xlm-roberta-large/resolve/main/config.json",
- "xlm-roberta-large-finetuned-conll02-dutch": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/config.json",
- "xlm-roberta-large-finetuned-conll02-spanish": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json",
- "xlm-roberta-large-finetuned-conll03-english": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/config.json",
- "xlm-roberta-large-finetuned-conll03-german": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/config.json",
+ "xlm-roberta-large-finetuned-conll02-dutch": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/config.json"
+ ),
+ "xlm-roberta-large-finetuned-conll02-spanish": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json"
+ ),
+ "xlm-roberta-large-finetuned-conll03-english": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/config.json"
+ ),
+ "xlm-roberta-large-finetuned-conll03-german": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
index 072933a12ea6..40928d8dc306 100644
--- a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
@@ -35,10 +35,18 @@
"vocab_file": {
"xlm-roberta-base": "https://huggingface.co/xlm-roberta-base/resolve/main/sentencepiece.bpe.model",
"xlm-roberta-large": "https://huggingface.co/xlm-roberta-large/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll02-dutch": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll02-spanish": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll03-english": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll03-german": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/sentencepiece.bpe.model",
+ "xlm-roberta-large-finetuned-conll02-dutch": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/sentencepiece.bpe.model"
+ ),
+ "xlm-roberta-large-finetuned-conll02-spanish": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model"
+ ),
+ "xlm-roberta-large-finetuned-conll03-english": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/sentencepiece.bpe.model"
+ ),
+ "xlm-roberta-large-finetuned-conll03-german": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/sentencepiece.bpe.model"
+ ),
}
}
diff --git a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py
index 119d2fa080f2..f99e3c086a88 100644
--- a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py
+++ b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py
@@ -38,18 +38,34 @@
"vocab_file": {
"xlm-roberta-base": "https://huggingface.co/xlm-roberta-base/resolve/main/sentencepiece.bpe.model",
"xlm-roberta-large": "https://huggingface.co/xlm-roberta-large/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll02-dutch": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll02-spanish": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll03-english": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll03-german": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/sentencepiece.bpe.model",
+ "xlm-roberta-large-finetuned-conll02-dutch": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/sentencepiece.bpe.model"
+ ),
+ "xlm-roberta-large-finetuned-conll02-spanish": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model"
+ ),
+ "xlm-roberta-large-finetuned-conll03-english": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/sentencepiece.bpe.model"
+ ),
+ "xlm-roberta-large-finetuned-conll03-german": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/sentencepiece.bpe.model"
+ ),
},
"tokenizer_file": {
"xlm-roberta-base": "https://huggingface.co/xlm-roberta-base/resolve/main/tokenizer.json",
"xlm-roberta-large": "https://huggingface.co/xlm-roberta-large/resolve/main/tokenizer.json",
- "xlm-roberta-large-finetuned-conll02-dutch": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/tokenizer.json",
- "xlm-roberta-large-finetuned-conll02-spanish": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/tokenizer.json",
- "xlm-roberta-large-finetuned-conll03-english": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/tokenizer.json",
- "xlm-roberta-large-finetuned-conll03-german": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/tokenizer.json",
+ "xlm-roberta-large-finetuned-conll02-dutch": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/tokenizer.json"
+ ),
+ "xlm-roberta-large-finetuned-conll02-spanish": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/tokenizer.json"
+ ),
+ "xlm-roberta-large-finetuned-conll03-english": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/tokenizer.json"
+ ),
+ "xlm-roberta-large-finetuned-conll03-german": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/xlm_roberta_xl/__init__.py b/src/transformers/models/xlm_roberta_xl/__init__.py
index 765a235f29e3..3140e3bd2267 100644
--- a/src/transformers/models/xlm_roberta_xl/__init__.py
+++ b/src/transformers/models/xlm_roberta_xl/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
@@ -29,7 +29,12 @@
],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_xlm_roberta_xl"] = [
"XLM_ROBERTA_XL_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLMRobertaXLForCausalLM",
@@ -49,7 +54,12 @@
XLMRobertaXLOnnxConfig,
)
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_xlm_roberta_xl import (
XLM_ROBERTA_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMRobertaXLForCausalLM,
diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
index ab46aa8f0322..aa41466767d6 100644
--- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
+++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
@@ -19,7 +19,6 @@
import torch
import torch.utils.checkpoint
-from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -35,7 +34,12 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ is_torch_greater_than_1_6,
+ prune_linear_layer,
+)
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -76,7 +80,7 @@ def __init__(self, config):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
- if version.parse(torch.__version__) > version.parse("1.6.0"):
+ if is_torch_greater_than_1_6:
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long),
@@ -415,7 +419,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
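
The import swap above replaces a per-instance packaging comparison with a boolean computed once at import time. A sketch of the pattern; TORCH_GREATER_THAN_1_6 is a local stand-in for transformers' is_torch_greater_than_1_6.

# Editor's sketch (not part of the patch): compute the version gate once, reuse it everywhere.
import torch
from packaging import version

TORCH_GREATER_THAN_1_6 = version.parse(torch.__version__.split("+")[0]) > version.parse("1.6.0")

if TORCH_GREATER_THAN_1_6:
    print("registering the `token_type_ids` buffer (torch > 1.6)")
else:
    print("skipping the buffer on older torch")
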
diff --git a/src/transformers/models/xlnet/__init__.py b/src/transformers/models/xlnet/__init__.py
index 599448a271df..d01edf267cc1 100644
--- a/src/transformers/models/xlnet/__init__.py
+++ b/src/transformers/models/xlnet/__init__.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_tf_available,
@@ -27,17 +28,30 @@
)
-_import_structure = {
- "configuration_xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"],
-}
+_import_structure = {"configuration_xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_xlnet"] = ["XLNetTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_xlnet_fast"] = ["XLNetTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_xlnet"] = [
"XLNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLNetForMultipleChoice",
@@ -51,7 +65,12 @@
"load_tf_weights_in_xlnet",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_xlnet"] = [
"TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFXLNetForMultipleChoice",
@@ -68,13 +87,28 @@
if TYPE_CHECKING:
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_xlnet import XLNetTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_xlnet_fast import XLNetTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_xlnet import (
XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
XLNetForMultipleChoice,
@@ -88,7 +122,12 @@
load_tf_weights_in_xlnet,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_xlnet import (
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLNetForMultipleChoice,
diff --git a/src/transformers/models/xlnet/configuration_xlnet.py b/src/transformers/models/xlnet/configuration_xlnet.py
index bc6f0f68356f..5448f9248ced 100644
--- a/src/transformers/models/xlnet/configuration_xlnet.py
+++ b/src/transformers/models/xlnet/configuration_xlnet.py
@@ -219,7 +219,8 @@ def __init__(
if "use_cache" in kwargs:
warnings.warn(
- "The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems_eval` instead.",
+ "The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems_eval`"
+ " instead.",
FutureWarning,
)
use_mems_eval = kwargs["use_cache"]
diff --git a/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py
index f6fc73ca0e58..804b52b0dc87 100755
--- a/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py
@@ -88,8 +88,10 @@ def convert_xlnet_checkpoint_to_pytorch(
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained XLNet model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained XLNet model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_folder_path",
diff --git a/src/transformers/models/xlnet/modeling_tf_xlnet.py b/src/transformers/models/xlnet/modeling_tf_xlnet.py
index df4111d26317..2e2fb1ea0875 100644
--- a/src/transformers/models/xlnet/modeling_tf_xlnet.py
+++ b/src/transformers/models/xlnet/modeling_tf_xlnet.py
@@ -1192,6 +1192,8 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFXLNetMainLayer(config, name="transformer")
self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name="lm_loss")
+ # generate fails to convert to a graph with XLNet
+ self.supports_xla_generation = False
def get_lm_head(self):
return self.lm_loss
@@ -1202,7 +1204,6 @@ def get_prefix_bias_name(self):
def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwargs):
# Add dummy token at the end (no attention on this one)
-
effective_batch_size = inputs.shape[0]
dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
@@ -1212,12 +1213,12 @@ def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwar
offset = 2
if past:
- inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
+ input_ids = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
else:
- inputs = tf.concat([inputs, dummy_token], axis=1)
+ input_ids = tf.concat([inputs, dummy_token], axis=1)
# Build permutation mask so that previous tokens don't see last token
- sequence_length = inputs.shape[1]
+ sequence_length = input_ids.shape[1]
perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)
@@ -1228,7 +1229,7 @@ def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwar
target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
inputs = {
- "input_ids": inputs,
+ "input_ids": input_ids,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
"use_mems": use_mems,
diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py
index dc7f78eeb8e2..4a299a5a657f 100755
--- a/src/transformers/models/xlnet/modeling_xlnet.py
+++ b/src/transformers/models/xlnet/modeling_xlnet.py
@@ -1056,7 +1056,6 @@ def relative_positional_encoding(self, qlen, klen, bsz=None):
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
- pos_emb = pos_emb.to(self.device)
return pos_emb
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1092,7 +1091,8 @@ def forward(
if "use_cache" in kwargs:
warnings.warn(
- "The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems` instead.",
+ "The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems`"
+ " instead.",
FutureWarning,
)
use_mems = kwargs["use_cache"]
@@ -1205,6 +1205,7 @@ def forward(
# Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
+ pos_emb = pos_emb.to(output_h.device)
pos_emb = self.dropout(pos_emb)
# Prepare head mask if needed
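
The positional-encoding change above moves the .to(...) call into forward so the embedding follows the hidden states (output_h) rather than self.device, which matters when the layers are spread across several devices. A toy sketch with a made-up encoding function.

# Editor's sketch (not part of the patch): place auxiliary tensors on the activations' device.
import torch


def relative_positional_encoding(qlen: int, klen: int, d_model: int = 8) -> torch.Tensor:
    # stand-in for the real sinusoidal encoding; created on the default (CPU) device
    return torch.randn(qlen + klen, 1, d_model)


output_h = torch.zeros(4, 1, 8)           # hidden states; may live on any device or shard
pos_emb = relative_positional_encoding(4, 4)
pos_emb = pos_emb.to(output_h.device)     # follow the activations, not the module
print(pos_emb.device == output_h.device)
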
diff --git a/src/transformers/models/yolos/__init__.py b/src/transformers/models/yolos/__init__.py
index fcdf387c68d6..6ae73421a831 100644
--- a/src/transformers/models/yolos/__init__.py
+++ b/src/transformers/models/yolos/__init__.py
@@ -17,17 +17,25 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
-_import_structure = {
- "configuration_yolos": ["YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP", "YolosConfig"],
-}
+_import_structure = {"configuration_yolos": ["YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP", "YolosConfig", "YolosOnnxConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_yolos"] = ["YolosFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_yolos"] = [
"YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST",
"YolosForObjectDetection",
@@ -37,12 +45,22 @@
if TYPE_CHECKING:
- from .configuration_yolos import YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP, YolosConfig
+ from .configuration_yolos import YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP, YolosConfig, YolosOnnxConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_yolos import YolosFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_yolos import (
YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST,
YolosForObjectDetection,
diff --git a/src/transformers/models/yolos/configuration_yolos.py b/src/transformers/models/yolos/configuration_yolos.py
index cd3414a7f26e..179d2833a121 100644
--- a/src/transformers/models/yolos/configuration_yolos.py
+++ b/src/transformers/models/yolos/configuration_yolos.py
@@ -14,7 +14,13 @@
# limitations under the License.
""" YOLOS model configuration"""
+from collections import OrderedDict
+from typing import Mapping
+
+from packaging import version
+
from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
from ...utils import logging
@@ -151,3 +157,24 @@ def __init__(
self.bbox_loss_coefficient = bbox_loss_coefficient
self.giou_loss_coefficient = giou_loss_coefficient
self.eos_coefficient = eos_coefficient
+
+
+class YolosOnnxConfig(OnnxConfig):
+
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-4
+
+ @property
+ def default_onnx_opset(self) -> int:
+ return 12
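For illustration only, a minimal sketch of driving the new YolosOnnxConfig through the Python export API; the hustvl/yolos-tiny checkpoint name is an assumption, and export/validate_model_outputs are the helpers touched later in this patch:

from pathlib import Path

from transformers import AutoFeatureExtractor, YolosForObjectDetection
from transformers.models.yolos import YolosOnnxConfig
from transformers.onnx import export, validate_model_outputs

ckpt = "hustvl/yolos-tiny"  # assumed checkpoint
feature_extractor = AutoFeatureExtractor.from_pretrained(ckpt)
model = YolosForObjectDetection.from_pretrained(ckpt)

# "object-detection" is the feature registered for YOLOS further down in onnx/features.py
onnx_config = YolosOnnxConfig(model.config, task="object-detection")

onnx_path = Path("yolos.onnx")
onnx_inputs, onnx_outputs = export(
    feature_extractor, model, onnx_config, onnx_config.default_onnx_opset, onnx_path
)

# Compare PyTorch and ONNX Runtime outputs within the 1e-4 tolerance defined above
validate_model_outputs(
    onnx_config, feature_extractor, model, onnx_path, onnx_outputs, onnx_config.atol_for_validation
)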
diff --git a/src/transformers/models/yolos/convert_yolos_to_pytorch.py b/src/transformers/models/yolos/convert_yolos_to_pytorch.py
index add0ae772db1..7f4161a632d8 100644
--- a/src/transformers/models/yolos/convert_yolos_to_pytorch.py
+++ b/src/transformers/models/yolos/convert_yolos_to_pytorch.py
@@ -247,7 +247,10 @@ def convert_yolos_checkpoint(yolos_name, checkpoint_path, pytorch_dump_folder_pa
"--yolos_name",
default="yolos_s_200_pre",
type=str,
- help="Name of the YOLOS model you'd like to convert. Should be one of 'yolos_ti', 'yolos_s_200_pre', 'yolos_s_300_pre', 'yolos_s_dWr', 'yolos_base'.",
+ help=(
+ "Name of the YOLOS model you'd like to convert. Should be one of 'yolos_ti', 'yolos_s_200_pre',"
+ " 'yolos_s_300_pre', 'yolos_s_dWr', 'yolos_base'."
+ ),
)
parser.add_argument(
"--checkpoint_path", default=None, type=str, help="Path to the original state dict (.pth file)."
diff --git a/src/transformers/models/yolos/feature_extraction_yolos.py b/src/transformers/models/yolos/feature_extraction_yolos.py
index 76b64ec83775..e199d1ae7bf4 100644
--- a/src/transformers/models/yolos/feature_extraction_yolos.py
+++ b/src/transformers/models/yolos/feature_extraction_yolos.py
@@ -537,7 +537,8 @@ def __call__(
valid_masks_path = True
if not valid_masks_path:
raise ValueError(
- "The path to the directory containing the mask PNG files should be provided as a `pathlib.Path` object."
+ "The path to the directory containing the mask PNG files should be provided as a"
+ " `pathlib.Path` object."
)
if not is_batched:
diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py
index 86ef903167d6..447cec23de97 100755
--- a/src/transformers/models/yolos/modeling_yolos.py
+++ b/src/transformers/models/yolos/modeling_yolos.py
@@ -111,13 +111,6 @@ class YolosObjectDetectionOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
-# Copied from transformers.models.vit.modeling_vit.to_2tuple
-def to_2tuple(x):
- if isinstance(x, collections.abc.Iterable):
- return x
- return (x, x)
-
-
class YolosEmbeddings(nn.Module):
"""
Construct the CLS token, detection tokens, position and patch embeddings.
@@ -129,12 +122,7 @@ def __init__(self, config: YolosConfig) -> None:
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.detection_tokens = nn.Parameter(torch.zeros(1, config.num_detection_tokens, config.hidden_size))
- self.patch_embeddings = PatchEmbeddings(
- image_size=config.image_size,
- patch_size=config.patch_size,
- num_channels=config.num_channels,
- embed_dim=config.hidden_size,
- )
+ self.patch_embeddings = YolosPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(
torch.zeros(1, num_patches + config.num_detection_tokens + 1, config.hidden_size)
@@ -228,32 +216,35 @@ def forward(self, pos_embed, img_size=(800, 1344)) -> torch.Tensor:
return scale_pos_embed
-# Based on timm implementation, which can be found here:
-# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-class PatchEmbeddings(nn.Module):
+class YolosPatchEmbeddings(nn.Module):
"""
- Image to Patch Embedding.
-
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
"""
- def __init__(
- self,
- image_size: int = 224,
- patch_size: Union[int, Tuple[int, int]] = 16,
- num_channels: int = 3,
- embed_dim: int = 768,
- ):
+ def __init__(self, config):
super().__init__()
- image_size = to_2tuple(image_size)
- patch_size = to_2tuple(patch_size)
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
+ self.num_channels = num_channels
self.num_patches = num_patches
- self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
return embeddings
@@ -280,7 +271,7 @@ def __init__(self, config: YolosConfig) -> None:
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
@@ -312,7 +303,7 @@ def forward(
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -620,7 +611,7 @@ def __init__(self, config: YolosConfig, add_pooling_layer: bool = True):
# Initialize weights and apply final processing
self.post_init()
- def get_input_embeddings(self) -> PatchEmbeddings:
+ def get_input_embeddings(self) -> YolosPatchEmbeddings:
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
@@ -1078,7 +1069,7 @@ def forward(self, outputs, targets):
# Retrieve the matching between the outputs of the last layer and the targets
indices = self.matcher(outputs_without_aux, targets)
- # Compute the average number of target boxes accross all nodes, for normalization purposes
+ # Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
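As a quick sanity check on the rewritten patch embedding, a standalone sketch of the shape math (the 224x224/16x16/768 values are illustrative ViT-style numbers, not the YOLOS defaults):

import torch
from torch import nn

image_size, patch_size, num_channels, hidden_size = (224, 224), (16, 16), 3, 768
num_patches = (image_size[0] // patch_size[0]) * (image_size[1] // patch_size[1])  # 14 * 14 = 196

projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
pixel_values = torch.randn(2, num_channels, *image_size)
embeddings = projection(pixel_values).flatten(2).transpose(1, 2)
assert embeddings.shape == (2, num_patches, hidden_size)  # (batch_size, seq_length, hidden_size)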
diff --git a/src/transformers/models/yoso/__init__.py b/src/transformers/models/yoso/__init__.py
index 5dff89595ca1..400a0303c0c7 100644
--- a/src/transformers/models/yoso/__init__.py
+++ b/src/transformers/models/yoso/__init__.py
@@ -18,14 +18,17 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
-_import_structure = {
- "configuration_yoso": ["YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP", "YosoConfig"],
-}
+_import_structure = {"configuration_yoso": ["YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP", "YosoConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_yoso"] = [
"YOSO_PRETRAINED_MODEL_ARCHIVE_LIST",
"YosoForMaskedLM",
@@ -42,7 +45,12 @@
if TYPE_CHECKING:
from .configuration_yoso import YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP, YosoConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_yoso import (
YOSO_PRETRAINED_MODEL_ARCHIVE_LIST,
YosoForMaskedLM,
diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py
index 50013ca03209..085d46bdfb55 100644
--- a/src/transformers/models/yoso/modeling_yoso.py
+++ b/src/transformers/models/yoso/modeling_yoso.py
@@ -21,7 +21,6 @@
import torch
import torch.utils.checkpoint
-from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -35,7 +34,12 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ is_torch_greater_than_1_6,
+ prune_linear_layer,
+)
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_yoso import YosoConfig
@@ -257,7 +261,7 @@ def __init__(self, config):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
- if version.parse(torch.__version__) > version.parse("1.6.0"):
+ if is_torch_greater_than_1_6:
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
@@ -1160,17 +1164,17 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1247,18 +1251,18 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- start_positions=None,
- end_positions=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.Tensor] = None,
+ end_positions: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
diff --git a/src/transformers/onnx/__main__.py b/src/transformers/onnx/__main__.py
index 6e3b4404cd04..6d665b35566f 100644
--- a/src/transformers/onnx/__main__.py
+++ b/src/transformers/onnx/__main__.py
@@ -15,9 +15,8 @@
from argparse import ArgumentParser
from pathlib import Path
-from ..models.auto import AutoConfig, AutoFeatureExtractor, AutoTokenizer
-from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
-from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
+from ..models.auto import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+from ..onnx.utils import get_preprocessor
from ..utils import logging
from .convert import export, validate_model_outputs
from .features import FeaturesManager
@@ -43,6 +42,13 @@ def main():
)
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
+ parser.add_argument(
+ "--preprocessor",
+ type=str,
+ choices=["auto", "tokenizer", "feature_extractor", "processor"],
+ default="auto",
+ help="Which type of preprocessor to use. 'auto' tries to automatically detect it.",
+ )
# Retrieve CLI arguments
args = parser.parse_args()
@@ -51,15 +57,17 @@ def main():
if not args.output.parent.exists():
args.output.parent.mkdir(parents=True)
- # Check the modality of the inputs and instantiate the appropriate preprocessor
- # TODO(lewtun): Refactor this as a function if we need to check modalities elsewhere as well
- config = AutoConfig.from_pretrained(args.model)
- if config.model_type in TOKENIZER_MAPPING_NAMES:
+ # Instantiate the appropriate preprocessor
+ if args.preprocessor == "auto":
+ preprocessor = get_preprocessor(args.model)
+ elif args.preprocessor == "tokenizer":
preprocessor = AutoTokenizer.from_pretrained(args.model)
- elif config.model_type in FEATURE_EXTRACTOR_MAPPING_NAMES:
+ elif args.preprocessor == "feature_extractor":
preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
+ elif args.preprocessor == "processor":
+ preprocessor = AutoProcessor.from_pretrained(args.model)
else:
- raise ValueError(f"Unsupported model type: {config.model_type}")
+ raise ValueError(f"Unknown preprocessor type '{args.preprocessor}'")
# Allocate the model
model = FeaturesManager.get_model_from_feature(
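A hedged usage sketch of the extended CLI; the checkpoint and feature values are assumptions, only the --preprocessor flag is new here:

python -m transformers.onnx --model=hustvl/yolos-tiny --feature=object-detection --preprocessor=feature_extractor onnx/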
diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py
index 6015bd374013..fdcc12bdcd1f 100644
--- a/src/transformers/onnx/config.py
+++ b/src/transformers/onnx/config.py
@@ -77,9 +77,22 @@ class OnnxConfig(ABC):
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
"image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
+ "image-segmentation": OrderedDict(
+ {
+ "logits": {0: "batch", 1: "sequence"},
+ "pred_boxes": {0: "batch", 1: "sequence"},
+ "pred_masks": {0: "batch", 1: "sequence"},
+ }
+ ),
"masked-im": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
+ "object-detection": OrderedDict(
+ {
+ "logits": {0: "batch", 1: "sequence"},
+ "pred_boxes": {0: "batch", 1: "sequence"},
+ }
+ ),
"question-answering": OrderedDict(
{
"start_logits": {0: "batch", 1: "sequence"},
@@ -293,7 +306,8 @@ def generate_dummy_inputs(
raise ValueError("You cannot provide both a tokenizer and a preprocessor to generate dummy inputs.")
if tokenizer is not None:
warnings.warn(
- "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
+ "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
+ " `preprocessor` instead.",
FutureWarning,
)
logger.warning("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
@@ -410,7 +424,8 @@ def num_layers(self) -> int:
"""
if not hasattr(self._config, "num_layers"):
raise AttributeError(
- "could not find the number of layers attribute in the model configuration, override the num_layers property of the model OnnxConfig to solve this"
+ "could not find the number of layers attribute in the model configuration, override the num_layers"
+ " property of the model OnnxConfig to solve this"
)
return self._config.num_layers
@@ -422,7 +437,8 @@ def num_attention_heads(self) -> int:
"""
if not hasattr(self._config, "num_attention_heads"):
raise AttributeError(
- "could not find the number of attention heads attribute in the model configuration, override the num_attention_heads property of the model OnnxConfig to solve this"
+ "could not find the number of attention heads attribute in the model configuration, override the"
+ " num_attention_heads property of the model OnnxConfig to solve this"
)
return self._config.num_attention_heads
@@ -457,8 +473,10 @@ def generate_dummy_inputs(
)
if "attention_mask" in common_inputs:
+ mask_dtype = common_inputs["attention_mask"].dtype
common_inputs["attention_mask"] = torch.cat(
- [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+ [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)],
+ dim=1,
)
common_inputs["past_key_values"] = []
@@ -469,7 +487,7 @@ def generate_dummy_inputs(
def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str):
"""
- Fill the input_or_ouputs mapping with past_key_values dynamic axes considering.
+ Fill the inputs_or_outputs mapping with past_key_values dynamic axes, considering the given direction.
Args:
inputs_or_outputs: The mapping to fill.
@@ -528,7 +546,8 @@ def num_layers(self) -> Tuple[int]:
num_layers = (self._config.encoder_layers, self._config.decoder_layers)
else:
raise AttributeError(
- "could not find the number of encoder and decoder layers attributes in the model configuration, override the num_layers property of the model OnnxConfig to solve this"
+ "could not find the number of encoder and decoder layers attributes in the model configuration,"
+ " override the num_layers property of the model OnnxConfig to solve this"
)
return num_layers
@@ -543,7 +562,9 @@ def num_attention_heads(self) -> Tuple[int]:
num_attention_heads = (self._config.encoder_attention_heads, self._config.decoder_attention_heads)
else:
raise AttributeError(
- "could not find the number of attention heads for the encoder and the decoder attributes in the model configuration, override the num_attention_heads property of the model OnnxConfig to solve this"
+ "could not find the number of attention heads for the encoder and the decoder attributes in the"
+ " model configuration, override the num_attention_heads property of the model OnnxConfig to solve"
+ " this"
)
return num_attention_heads
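The attention_mask change above keeps the extension in the same dtype as the original mask; a minimal reproduction of the idea (values assumed, not code from the diff):

import torch

batch, seq_len, past_len = 2, 3, 4
attention_mask = torch.ones(batch, seq_len, dtype=torch.int64)  # tokenizers return integer masks
extension = torch.ones(batch, past_len, dtype=attention_mask.dtype)  # reuse the mask dtype
extended_mask = torch.cat([attention_mask, extension], dim=1)
assert extended_mask.dtype == torch.int64 and extended_mask.shape == (batch, seq_len + past_len)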
diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py
index 69aca2a43acc..a896b76a1cca 100644
--- a/src/transformers/onnx/convert.py
+++ b/src/transformers/onnx/convert.py
@@ -34,12 +34,14 @@
if is_torch_available():
from ..modeling_utils import PreTrainedModel
+ from ..pytorch_utils import is_torch_less_than_1_11
if is_tf_available():
from ..modeling_tf_utils import TFPreTrainedModel
if TYPE_CHECKING:
from ..feature_extraction_utils import FeatureExtractionMixin
+ from ..processing_utils import ProcessorMixin
from ..tokenization_utils import PreTrainedTokenizer
@@ -68,7 +70,7 @@ def check_onnxruntime_requirements(minimum_version: Version):
raise ImportError(
f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
- f"Please update onnxruntime by running `pip install --upgrade onnxruntime`"
+ "Please update onnxruntime by running `pip install --upgrade onnxruntime`"
)
except ImportError:
@@ -80,18 +82,19 @@ def check_onnxruntime_requirements(minimum_version: Version):
def export_pytorch(
- preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin"],
+ preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin", "ProcessorMixin"],
model: "PreTrainedModel",
config: OnnxConfig,
opset: int,
output: Path,
tokenizer: "PreTrainedTokenizer" = None,
+ device: str = "cpu",
) -> Tuple[List[str], List[str]]:
"""
Export a PyTorch model to an ONNX Intermediate Representation (IR)
Args:
- preprocessor: ([`PreTrainedTokenizer`] or [`FeatureExtractionMixin`]):
+ preprocessor: ([`PreTrainedTokenizer`], [`FeatureExtractionMixin`] or [`ProcessorMixin`]):
The preprocessor used for encoding the data.
model ([`PreTrainedModel`]):
The model to export.
@@ -101,6 +104,8 @@ def export_pytorch(
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.
+ device (`str`, *optional*, defaults to `cpu`):
+ The device on which the ONNX model will be exported. Either `cpu` or `cuda`.
Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -111,7 +116,8 @@ def export_pytorch(
raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.")
if tokenizer is not None:
warnings.warn(
- "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
+ "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
+ " `preprocessor` instead.",
FutureWarning,
)
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
@@ -136,6 +142,10 @@ def export_pytorch(
# Ensure inputs match
# TODO: Check when exporting QA we provide "is_pair=True"
model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH)
+ device = torch.device(device)
+ if device.type == "cuda" and torch.cuda.is_available():
+ model.to(device)
+ model_inputs = dict((k, v.to(device)) for k, v in model_inputs.items())
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())
@@ -146,7 +156,7 @@ def export_pytorch(
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
# so we check the torch version for backwards compatibility
- if parse(torch.__version__) < parse("1.10"):
+ if is_torch_less_than_1_11:
# export can work with named args but the dict containing named args
# has to be the last element of the args tuple.
try:
@@ -168,9 +178,13 @@ def export_pytorch(
message = str(err)
if (
message
- == "Exporting model exceed maximum protobuf size of 2GB. Please call torch.onnx.export without setting use_external_data_format parameter."
+ == "Exporting model exceed maximum protobuf size of 2GB. Please call torch.onnx.export without"
+ " setting use_external_data_format parameter."
):
- message = "Exporting model exceed maximum protobuf size of 2GB. Please call torch.onnx.export without setting use_external_data_format parameter or try with torch 1.10+."
+ message = (
+ "Exporting model exceed maximum protobuf size of 2GB. Please call torch.onnx.export"
+ " without setting use_external_data_format parameter or try with torch 1.10+."
+ )
raise RuntimeError(message)
else:
raise err
@@ -227,7 +241,8 @@ def export_tensorflow(
raise ValueError("You cannot provide both a tokenizer and preprocessor to export the model.")
if tokenizer is not None:
warnings.warn(
- "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
+ "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
+ " `preprocessor` instead.",
FutureWarning,
)
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
@@ -256,18 +271,19 @@ def export_tensorflow(
def export(
- preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin"],
+ preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin", "ProcessorMixin"],
model: Union["PreTrainedModel", "TFPreTrainedModel"],
config: OnnxConfig,
opset: int,
output: Path,
tokenizer: "PreTrainedTokenizer" = None,
+ device: str = "cpu",
) -> Tuple[List[str], List[str]]:
"""
Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR)
Args:
- preprocessor: ([`PreTrainedTokenizer`] or [`FeatureExtractionMixin`]):
+ preprocessor: ([`PreTrainedTokenizer`], [`FeatureExtractionMixin`] or [`ProcessorMixin`]):
The preprocessor used for encoding the data.
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model to export.
@@ -277,6 +293,9 @@ def export(
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.
+ device (`str`, *optional*, defaults to `cpu`):
+ The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
+ export on CUDA devices.
Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -288,11 +307,15 @@ def export(
"Please install torch or tensorflow first."
)
+ if is_tf_available() and isinstance(model, TFPreTrainedModel) and device == "cuda":
+ raise RuntimeError("`tf2onnx` does not support export on CUDA device.")
+
if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.")
if tokenizer is not None:
warnings.warn(
- "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
+ "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
+ " `preprocessor` instead.",
FutureWarning,
)
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
@@ -306,18 +329,19 @@ def export(
if not config.is_torch_support_available:
logger.warning(
- f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version}, got: {torch_version}"
+ f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version},"
+ f" got: {torch_version}"
)
if is_torch_available() and issubclass(type(model), PreTrainedModel):
- return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer)
+ return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer, device=device)
elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
return export_tensorflow(preprocessor, model, config, opset, output, tokenizer=tokenizer)
def validate_model_outputs(
config: OnnxConfig,
- preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin"],
+ preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin", "ProcessorMixin"],
reference_model: Union["PreTrainedModel", "TFPreTrainedModel"],
onnx_model: Path,
onnx_named_outputs: List[str],
@@ -329,10 +353,11 @@ def validate_model_outputs(
logger.info("Validating ONNX model...")
if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
- raise ValueError("You cannot provide both a tokenizer and a preprocessor to validatethe model outputs.")
+ raise ValueError("You cannot provide both a tokenizer and a preprocessor to validate the model outputs.")
if tokenizer is not None:
warnings.warn(
- "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
+ "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
+ " `preprocessor` instead.",
FutureWarning,
)
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
@@ -350,6 +375,8 @@ def validate_model_outputs(
session = InferenceSession(onnx_model.as_posix(), options, providers=["CPUExecutionProvider"])
# Compute outputs from the reference model
+ if is_torch_available() and issubclass(type(reference_model), PreTrainedModel):
+ reference_model.to("cpu")
ref_outputs = reference_model(**reference_model_inputs)
ref_outputs_dict = {}
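A sketch of the new device argument (the checkpoint name is an assumption; DistilBertOnnxConfig and export are existing transformers.onnx APIs):

from pathlib import Path

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.onnx import export

ckpt = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForSequenceClassification.from_pretrained(ckpt)
onnx_config = DistilBertOnnxConfig(model.config, task="sequence-classification")

# Run the export on GPU when one is available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("sst2.onnx"), device=device)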
diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py
index 50e941332a95..8d8b8190e468 100644
--- a/src/transformers/onnx/features.py
+++ b/src/transformers/onnx/features.py
@@ -1,38 +1,17 @@
from functools import partial, reduce
-from typing import Callable, Dict, Optional, Tuple, Type, Union
-
-from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
-from ..models.albert import AlbertOnnxConfig
-from ..models.bart import BartOnnxConfig
-from ..models.beit import BeitOnnxConfig
-from ..models.bert import BertOnnxConfig
-from ..models.big_bird import BigBirdOnnxConfig
-from ..models.blenderbot import BlenderbotOnnxConfig
-from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
-from ..models.camembert import CamembertOnnxConfig
-from ..models.convbert import ConvBertOnnxConfig
-from ..models.data2vec import Data2VecTextOnnxConfig
-from ..models.deit import DeiTOnnxConfig
-from ..models.distilbert import DistilBertOnnxConfig
-from ..models.electra import ElectraOnnxConfig
-from ..models.flaubert import FlaubertOnnxConfig
-from ..models.gpt2 import GPT2OnnxConfig
-from ..models.gpt_neo import GPTNeoOnnxConfig
-from ..models.gptj import GPTJOnnxConfig
-from ..models.ibert import IBertOnnxConfig
-from ..models.layoutlm import LayoutLMOnnxConfig
-from ..models.m2m_100 import M2M100OnnxConfig
-from ..models.marian import MarianOnnxConfig
-from ..models.mbart import MBartOnnxConfig
-from ..models.roberta import RobertaOnnxConfig
-from ..models.roformer import RoFormerOnnxConfig
-from ..models.t5 import T5OnnxConfig
-from ..models.vit import ViTOnnxConfig
-from ..models.xlm_roberta import XLMRobertaOnnxConfig
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, Union
+
+import transformers
+
+from .. import PretrainedConfig, is_tf_available, is_torch_available
from ..utils import logging
from .config import OnnxConfig
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel, TFPreTrainedModel
+
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_torch_available():
@@ -40,9 +19,11 @@
AutoModel,
AutoModelForCausalLM,
AutoModelForImageClassification,
+ AutoModelForImageSegmentation,
AutoModelForMaskedImageModeling,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
+ AutoModelForObjectDetection,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
@@ -61,19 +42,20 @@
)
if not is_torch_available() and not is_tf_available():
logger.warning(
- "The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed."
+ "The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models"
+ " without one of these libraries installed."
)
def supported_features_mapping(
- *supported_features: str, onnx_config_cls: Type[OnnxConfig] = None
+ *supported_features: str, onnx_config_cls: str = None
) -> Dict[str, Callable[[PretrainedConfig], OnnxConfig]]:
"""
Generate the mapping between the supported features and their corresponding OnnxConfig for a given model.
Args:
*supported_features: The names of the supported features.
- onnx_config_cls: The OnnxConfig class corresponding to the model.
+ onnx_config_cls: The OnnxConfig full name corresponding to the model.
Returns:
The dictionary mapping a feature to an OnnxConfig constructor.
@@ -81,13 +63,16 @@ def supported_features_mapping(
if onnx_config_cls is None:
raise ValueError("A OnnxConfig class must be provided")
+ config_cls = transformers
+ for attr_name in onnx_config_cls.split("."):
+ config_cls = getattr(config_cls, attr_name)
mapping = {}
for feature in supported_features:
if "-with-past" in feature:
task = feature.replace("-with-past", "")
- mapping[feature] = partial(onnx_config_cls.with_past, task=task)
+ mapping[feature] = partial(config_cls.with_past, task=task)
else:
- mapping[feature] = partial(onnx_config_cls.from_model_config, task=feature)
+ mapping[feature] = partial(config_cls.from_model_config, task=feature)
return mapping
@@ -104,8 +89,10 @@ class FeaturesManager:
"sequence-classification": AutoModelForSequenceClassification,
"token-classification": AutoModelForTokenClassification,
"multiple-choice": AutoModelForMultipleChoice,
+ "object-detection": AutoModelForObjectDetection,
"question-answering": AutoModelForQuestionAnswering,
"image-classification": AutoModelForImageClassification,
+ "image-segmentation": AutoModelForImageSegmentation,
"masked-im": AutoModelForMaskedImageModeling,
}
if is_tf_available():
@@ -129,7 +116,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=AlbertOnnxConfig,
+ onnx_config_cls="models.albert.AlbertOnnxConfig",
),
"bart": supported_features_mapping(
"default",
@@ -140,10 +127,12 @@ class FeaturesManager:
"seq2seq-lm-with-past",
"sequence-classification",
"question-answering",
- onnx_config_cls=BartOnnxConfig,
+ onnx_config_cls="models.bart.BartOnnxConfig",
),
# BEiT cannot be used with the masked image modeling autoclass, so this feature is excluded here
- "beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig),
+ "beit": supported_features_mapping(
+ "default", "image-classification", onnx_config_cls="models.beit.BeitOnnxConfig"
+ ),
"bert": supported_features_mapping(
"default",
"masked-lm",
@@ -152,7 +141,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=BertOnnxConfig,
+ onnx_config_cls="models.bert.BertOnnxConfig",
),
"big-bird": supported_features_mapping(
"default",
@@ -162,7 +151,18 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=BigBirdOnnxConfig,
+ onnx_config_cls="models.big_bird.BigBirdOnnxConfig",
+ ),
+ "bigbird-pegasus": supported_features_mapping(
+ "default",
+ "default-with-past",
+ "causal-lm",
+ "causal-lm-with-past",
+ "seq2seq-lm",
+ "seq2seq-lm-with-past",
+ "sequence-classification",
+ "question-answering",
+ onnx_config_cls="models.bigbird_pegasus.BigBirdPegasusOnnxConfig",
),
"blenderbot": supported_features_mapping(
"default",
@@ -171,7 +171,7 @@ class FeaturesManager:
"causal-lm-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
- onnx_config_cls=BlenderbotOnnxConfig,
+ onnx_config_cls="models.blenderbot.BlenderbotOnnxConfig",
),
"blenderbot-small": supported_features_mapping(
"default",
@@ -180,7 +180,16 @@ class FeaturesManager:
"causal-lm-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
- onnx_config_cls=BlenderbotSmallOnnxConfig,
+ onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig",
+ ),
+ "bloom": supported_features_mapping(
+ "default",
+ "default-with-past",
+ "causal-lm",
+ "causal-lm-with-past",
+ "sequence-classification",
+ "token-classification",
+ onnx_config_cls="models.bloom.BloomOnnxConfig",
),
"camembert": supported_features_mapping(
"default",
@@ -190,7 +199,12 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=CamembertOnnxConfig,
+ onnx_config_cls="models.camembert.CamembertOnnxConfig",
+ ),
+ "codegen": supported_features_mapping(
+ "default",
+ "causal-lm",
+ onnx_config_cls="models.codegen.CodeGenOnnxConfig",
),
"convbert": supported_features_mapping(
"default",
@@ -199,7 +213,12 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=ConvBertOnnxConfig,
+ onnx_config_cls="models.convbert.ConvBertOnnxConfig",
+ ),
+ "convnext": supported_features_mapping(
+ "default",
+ "image-classification",
+ onnx_config_cls="models.convnext.ConvNextOnnxConfig",
),
"data2vec-text": supported_features_mapping(
"default",
@@ -208,10 +227,39 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=Data2VecTextOnnxConfig,
+ onnx_config_cls="models.data2vec.Data2VecTextOnnxConfig",
+ ),
+ "data2vec-vision": supported_features_mapping(
+ "default",
+ "image-classification",
+ "image-segmentation",
+ onnx_config_cls="models.data2vec.Data2VecVisionOnnxConfig",
+ ),
+ "deberta": supported_features_mapping(
+ "default",
+ "masked-lm",
+ "sequence-classification",
+ "token-classification",
+ "question-answering",
+ onnx_config_cls="models.deberta.DebertaOnnxConfig",
+ ),
+ "deberta-v2": supported_features_mapping(
+ "default",
+ "masked-lm",
+ "sequence-classification",
+ "multiple-choice",
+ "token-classification",
+ "question-answering",
+ onnx_config_cls="models.deberta_v2.DebertaV2OnnxConfig",
),
"deit": supported_features_mapping(
- "default", "image-classification", "masked-im", onnx_config_cls=DeiTOnnxConfig
+ "default", "image-classification", "masked-im", onnx_config_cls="models.deit.DeiTOnnxConfig"
+ ),
+ "detr": supported_features_mapping(
+ "default",
+ "object-detection",
+ "image-segmentation",
+ onnx_config_cls="models.detr.DetrOnnxConfig",
),
"distilbert": supported_features_mapping(
"default",
@@ -220,7 +268,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=DistilBertOnnxConfig,
+ onnx_config_cls="models.distilbert.DistilBertOnnxConfig",
),
"electra": supported_features_mapping(
"default",
@@ -230,7 +278,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=ElectraOnnxConfig,
+ onnx_config_cls="models.electra.ElectraOnnxConfig",
),
"flaubert": supported_features_mapping(
"default",
@@ -240,7 +288,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=FlaubertOnnxConfig,
+ onnx_config_cls="models.flaubert.FlaubertOnnxConfig",
),
"gpt2": supported_features_mapping(
"default",
@@ -249,7 +297,7 @@ class FeaturesManager:
"causal-lm-with-past",
"sequence-classification",
"token-classification",
- onnx_config_cls=GPT2OnnxConfig,
+ onnx_config_cls="models.gpt2.GPT2OnnxConfig",
),
"gptj": supported_features_mapping(
"default",
@@ -258,7 +306,7 @@ class FeaturesManager:
"causal-lm-with-past",
"question-answering",
"sequence-classification",
- onnx_config_cls=GPTJOnnxConfig,
+ onnx_config_cls="models.gptj.GPTJOnnxConfig",
),
"gpt-neo": supported_features_mapping(
"default",
@@ -266,7 +314,7 @@ class FeaturesManager:
"causal-lm",
"causal-lm-with-past",
"sequence-classification",
- onnx_config_cls=GPTNeoOnnxConfig,
+ onnx_config_cls="models.gpt_neo.GPTNeoOnnxConfig",
),
"ibert": supported_features_mapping(
"default",
@@ -275,14 +323,31 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=IBertOnnxConfig,
+ onnx_config_cls="models.ibert.IBertOnnxConfig",
),
"layoutlm": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"token-classification",
- onnx_config_cls=LayoutLMOnnxConfig,
+ onnx_config_cls="models.layoutlm.LayoutLMOnnxConfig",
+ ),
+ "layoutlmv3": supported_features_mapping(
+ "default",
+ "question-answering",
+ "sequence-classification",
+ "token-classification",
+ onnx_config_cls="models.layoutlmv3.LayoutLMv3OnnxConfig",
+ ),
+ "levit": supported_features_mapping(
+ "default", "image-classification", onnx_config_cls="models.levit.LevitOnnxConfig"
+ ),
+ "longt5": supported_features_mapping(
+ "default",
+ "default-with-past",
+ "seq2seq-lm",
+ "seq2seq-lm-with-past",
+ onnx_config_cls="models.longt5.LongT5OnnxConfig",
),
"marian": supported_features_mapping(
"default",
@@ -291,7 +356,7 @@ class FeaturesManager:
"seq2seq-lm-with-past",
"causal-lm",
"causal-lm-with-past",
- onnx_config_cls=MarianOnnxConfig,
+ onnx_config_cls="models.marian.MarianOnnxConfig",
),
"mbart": supported_features_mapping(
"default",
@@ -302,10 +367,46 @@ class FeaturesManager:
"seq2seq-lm-with-past",
"sequence-classification",
"question-answering",
- onnx_config_cls=MBartOnnxConfig,
+ onnx_config_cls="models.mbart.MBartOnnxConfig",
+ ),
+ "mobilebert": supported_features_mapping(
+ "default",
+ "masked-lm",
+ "sequence-classification",
+ "multiple-choice",
+ "token-classification",
+ "question-answering",
+ onnx_config_cls="models.mobilebert.MobileBertOnnxConfig",
+ ),
+ "mobilevit": supported_features_mapping(
+ "default",
+ "image-classification",
+ onnx_config_cls="models.mobilevit.MobileViTOnnxConfig",
+ ),
+ "mt5": supported_features_mapping(
+ "default",
+ "default-with-past",
+ "seq2seq-lm",
+ "seq2seq-lm-with-past",
+ onnx_config_cls="models.mt5.MT5OnnxConfig",
),
"m2m-100": supported_features_mapping(
- "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
+ "default",
+ "default-with-past",
+ "seq2seq-lm",
+ "seq2seq-lm-with-past",
+ onnx_config_cls="models.m2m_100.M2M100OnnxConfig",
+ ),
+ "perceiver": supported_features_mapping(
+ "image-classification",
+ "masked-lm",
+ "sequence-classification",
+ onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
+ ),
+ "resnet": supported_features_mapping(
+ "default",
+ "image-classification",
+ onnx_config_cls="models.resnet.ResNetOnnxConfig",
),
"roberta": supported_features_mapping(
"default",
@@ -315,7 +416,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=RobertaOnnxConfig,
+ onnx_config_cls="models.roberta.RobertaOnnxConfig",
),
"roformer": supported_features_mapping(
"default",
@@ -326,13 +427,36 @@ class FeaturesManager:
"multiple-choice",
"question-answering",
"token-classification",
- onnx_config_cls=RoFormerOnnxConfig,
+ onnx_config_cls="models.roformer.RoFormerOnnxConfig",
+ ),
+ "squeezebert": supported_features_mapping(
+ "default",
+ "masked-lm",
+ "sequence-classification",
+ "multiple-choice",
+ "token-classification",
+ "question-answering",
+ onnx_config_cls="models.squeezebert.SqueezeBertOnnxConfig",
),
"t5": supported_features_mapping(
- "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
+ "default",
+ "default-with-past",
+ "seq2seq-lm",
+ "seq2seq-lm-with-past",
+ onnx_config_cls="models.t5.T5OnnxConfig",
),
"vit": supported_features_mapping(
- "default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
+ "default", "image-classification", "masked-im", onnx_config_cls="models.vit.ViTOnnxConfig"
+ ),
+ "xlm": supported_features_mapping(
+ "default",
+ "masked-lm",
+ "causal-lm",
+ "sequence-classification",
+ "multiple-choice",
+ "token-classification",
+ "question-answering",
+ onnx_config_cls="models.xlm.XLMOnnxConfig",
),
"xlm-roberta": supported_features_mapping(
"default",
@@ -342,7 +466,12 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=XLMRobertaOnnxConfig,
+ onnx_config_cls="models.xlm_roberta.XLMRobertaOnnxConfig",
+ ),
+ "yolos": supported_features_mapping(
+ "default",
+ "object-detection",
+ onnx_config_cls="models.yolos.YolosOnnxConfig",
),
}
@@ -415,15 +544,14 @@ def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type:
task_to_automodel = FeaturesManager._TASKS_TO_TF_AUTOMODELS
if task not in task_to_automodel:
raise KeyError(
- f"Unknown task: {feature}. "
- f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
+ f"Unknown task: {feature}. Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
)
return task_to_automodel[task]
@staticmethod
def get_model_from_feature(
feature: str, model: str, framework: str = "pt", cache_dir: str = None
- ) -> Union[PreTrainedModel, TFPreTrainedModel]:
+ ) -> Union["PreTrainedModel", "TFPreTrainedModel"]:
"""
Attempts to retrieve a model from a model's name and the feature to be enabled.
@@ -451,7 +579,7 @@ def get_model_from_feature(
@staticmethod
def check_supported_model_or_raise(
- model: Union[PreTrainedModel, TFPreTrainedModel], feature: str = "default"
+ model: Union["PreTrainedModel", "TFPreTrainedModel"], feature: str = "default"
) -> Tuple[str, Callable]:
"""
Check whether or not the model has the requested features.
@@ -469,8 +597,22 @@ def check_supported_model_or_raise(
model_features = FeaturesManager.get_supported_features_for_model_type(model_type, model_name=model_name)
if feature not in model_features:
raise ValueError(
- f"{model.config.model_type} doesn't support feature {feature}. "
- f"Supported values are: {model_features}"
+ f"{model.config.model_type} doesn't support feature {feature}. Supported values are: {model_features}"
)
return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]
+
+ def get_config(model_type: str, feature: str) -> OnnxConfig:
+ """
+ Gets the OnnxConfig for a model_type and feature combination.
+
+ Args:
+ model_type (`str`):
+ The model type to retrieve the config for.
+ feature (`str`):
+ The feature to retrieve the config for.
+
+ Returns:
+ `OnnxConfig`: config for the combination
+ """
+ return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]
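For context, a hedged sketch of how the string-based registry resolves at call time (using the YOLOS entry added above; FeaturesManager.get_config is the new helper defined at the end of this file):

from transformers import AutoConfig
from transformers.onnx import FeaturesManager

config = AutoConfig.from_pretrained("hustvl/yolos-tiny")  # assumed checkpoint
# Returns the partial built by supported_features_mapping, which now resolves
# "models.yolos.YolosOnnxConfig" via getattr instead of importing it eagerly.
onnx_config_ctor = FeaturesManager.get_config("yolos", "object-detection")
onnx_config = onnx_config_ctor(config)
print(type(onnx_config).__name__)  # YolosOnnxConfig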
diff --git a/src/transformers/onnx/utils.py b/src/transformers/onnx/utils.py
index def160e6c7bb..9672b0a96af8 100644
--- a/src/transformers/onnx/utils.py
+++ b/src/transformers/onnx/utils.py
@@ -14,6 +14,11 @@
from ctypes import c_float, sizeof
from enum import Enum
+from typing import TYPE_CHECKING, Optional, Union
+
+
+if TYPE_CHECKING:
+ from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer # tests_ignore
class ParameterFormat(Enum):
@@ -61,3 +66,44 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm
Size (in byte) taken to save all the parameters
"""
return num_parameters * dtype.size
+
+
+def get_preprocessor(model_name: str) -> Optional[Union["AutoTokenizer", "AutoFeatureExtractor", "AutoProcessor"]]:
+ """
+ Gets a preprocessor (tokenizer, feature extractor or processor) that is available for `model_name`.
+
+ Args:
+ model_name (`str`): Name of the model for which a preprocessor is loaded.
+
+ Returns:
+ `Optional[Union[AutoTokenizer, AutoFeatureExtractor, AutoProcessor]]`:
+ If a processor is found, it is returned. Otherwise, if a tokenizer or a feature extractor exists, it is
+ returned. If both a tokenizer and a feature extractor exist, an error is raised. The function returns
+ `None` if no preprocessor is found.
+ """
+ # Avoid circular imports by only importing this here.
+ from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer # tests_ignore
+
+ try:
+ return AutoProcessor.from_pretrained(model_name)
+ except (ValueError, OSError, KeyError):
+ tokenizer, feature_extractor = None, None
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ except (OSError, KeyError):
+ pass
+ try:
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
+ except (OSError, KeyError):
+ pass
+
+ if tokenizer is not None and feature_extractor is not None:
+ raise ValueError(
+ f"Couldn't auto-detect preprocessor for {model_name}. Found both a tokenizer and a feature extractor."
+ )
+ elif tokenizer is None and feature_extractor is None:
+ return None
+ elif tokenizer is not None:
+ return tokenizer
+ else:
+ return feature_extractor
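A short usage sketch for the new helper (checkpoint name assumed); vision checkpoints would resolve analogously to their feature extractor, and checkpoints shipping both a tokenizer and a feature extractor raise instead of guessing:

from transformers.onnx.utils import get_preprocessor

preprocessor = get_preprocessor("distilbert-base-cased")
print(type(preprocessor).__name__)  # a tokenizer class for this text-only checkpoint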
diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py
index 60b9dca7831b..b957acb6de93 100644
--- a/src/transformers/optimization.py
+++ b/src/transformers/optimization.py
@@ -304,8 +304,9 @@ def __init__(
):
if not no_deprecation_warning:
warnings.warn(
- "This implementation of AdamW is deprecated and will be removed in a future version. Use the"
- " PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning",
+ "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
+ " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
+ " warning",
FutureWarning,
)
require_version("torch>=1.5.0") # add_ with alpha
diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index 1350669e4516..dfa75768d8f8 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -23,13 +23,19 @@
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from numpy import isin
+
+from huggingface_hub.file_download import http_get
+
from ..configuration_utils import PretrainedConfig
+from ..dynamic_module_utils import get_class_from_dynamic_module
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..models.auto.configuration_auto import AutoConfig
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from ..tokenization_utils import PreTrainedTokenizer
-from ..utils import http_get, is_tf_available, is_torch_available, logging
+from ..tokenization_utils_fast import PreTrainedTokenizerFast
+from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, is_tf_available, is_torch_available, logging
from .audio_classification import AudioClassificationPipeline
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
from .base import (
@@ -40,7 +46,8 @@
Pipeline,
PipelineDataFormat,
PipelineException,
- get_default_model,
+ PipelineRegistry,
+ get_default_model_and_revision,
infer_framework_load_model,
)
from .conversational import Conversation, ConversationalPipeline
@@ -60,6 +67,7 @@
TokenClassificationArgumentHandler,
TokenClassificationPipeline,
)
+from .visual_question_answering import VisualQuestionAnsweringPipeline
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
@@ -75,6 +83,7 @@
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
TFAutoModelForCausalLM,
+ TFAutoModelForImageClassification,
TFAutoModelForMaskedLM,
TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM,
@@ -93,6 +102,7 @@
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+ MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
AutoModel,
AutoModelForAudioClassification,
AutoModelForCausalLM,
@@ -108,6 +118,7 @@
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
+ AutoModelForVisualQuestionAnswering,
)
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
@@ -120,27 +131,28 @@
TASK_ALIASES = {
"sentiment-analysis": "text-classification",
"ner": "token-classification",
+ "vqa": "visual-question-answering",
}
SUPPORTED_TASKS = {
"audio-classification": {
"impl": AudioClassificationPipeline,
"tf": (),
"pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
- "default": {"model": {"pt": "superb/wav2vec2-base-superb-ks"}},
+ "default": {"model": {"pt": ("superb/wav2vec2-base-superb-ks", "372e048")}},
"type": "audio",
},
"automatic-speech-recognition": {
"impl": AutomaticSpeechRecognitionPipeline,
"tf": (),
"pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
- "default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
+ "default": {"model": {"pt": ("facebook/wav2vec2-base-960h", "55bb623")}},
"type": "multimodal",
},
"feature-extraction": {
"impl": FeatureExtractionPipeline,
"tf": (TFAutoModel,) if is_tf_available() else (),
"pt": (AutoModel,) if is_torch_available() else (),
- "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
+ "default": {"model": {"pt": ("distilbert-base-cased", "935ac13"), "tf": ("distilbert-base-cased", "935ac13")}},
"type": "multimodal",
},
"text-classification": {
@@ -149,8 +161,8 @@
"pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
"default": {
"model": {
- "pt": "distilbert-base-uncased-finetuned-sst-2-english",
- "tf": "distilbert-base-uncased-finetuned-sst-2-english",
+ "pt": ("distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
+ "tf": ("distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
},
},
"type": "text",
@@ -161,8 +173,8 @@
"pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
"default": {
"model": {
- "pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
- "tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
+ "pt": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
+ "tf": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
},
},
"type": "text",
@@ -172,7 +184,10 @@
"tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
"pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
"default": {
- "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
+ "model": {
+ "pt": ("distilbert-base-cased-distilled-squad", "626af31"),
+ "tf": ("distilbert-base-cased-distilled-squad", "626af31"),
+ },
},
"type": "text",
},
@@ -182,25 +197,33 @@
"tf": (TFAutoModelForTableQuestionAnswering,) if is_tf_available() else (),
"default": {
"model": {
- "pt": "google/tapas-base-finetuned-wtq",
- "tokenizer": "google/tapas-base-finetuned-wtq",
- "tf": "google/tapas-base-finetuned-wtq",
+ "pt": ("google/tapas-base-finetuned-wtq", "69ceee2"),
+ "tf": ("google/tapas-base-finetuned-wtq", "69ceee2"),
},
},
"type": "text",
},
+ "visual-question-answering": {
+ "impl": VisualQuestionAnsweringPipeline,
+ "pt": (AutoModelForVisualQuestionAnswering,) if is_torch_available() else (),
+ "tf": (),
+ "default": {
+ "model": {"pt": ("dandelin/vilt-b32-finetuned-vqa", "4355f59")},
+ },
+ "type": "multimodal",
+ },
"fill-mask": {
"impl": FillMaskPipeline,
"tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
"pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
- "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
+ "default": {"model": {"pt": ("distilroberta-base", "ec58a5b"), "tf": ("distilroberta-base", "ec58a5b")}},
"type": "text",
},
"summarization": {
"impl": SummarizationPipeline,
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
- "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
+ "default": {"model": {"pt": ("sshleifer/distilbart-cnn-12-6", "a4f8f3e"), "tf": ("t5-small", "d769bba")}},
"type": "text",
},
# This task is a special case as it's parametrized by SRC, TGT languages.
@@ -209,9 +232,9 @@
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
"default": {
- ("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
- ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
- ("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
+ ("en", "fr"): {"model": {"pt": ("t5-base", "686f1db"), "tf": ("t5-base", "686f1db")}},
+ ("en", "de"): {"model": {"pt": ("t5-base", "686f1db"), "tf": ("t5-base", "686f1db")}},
+ ("en", "ro"): {"model": {"pt": ("t5-base", "686f1db"), "tf": ("t5-base", "686f1db")}},
},
"type": "text",
},
@@ -219,14 +242,14 @@
"impl": Text2TextGenerationPipeline,
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
- "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
+ "default": {"model": {"pt": ("t5-base", "686f1db"), "tf": ("t5-base", "686f1db")}},
"type": "text",
},
"text-generation": {
"impl": TextGenerationPipeline,
"tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
"pt": (AutoModelForCausalLM,) if is_torch_available() else (),
- "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
+ "default": {"model": {"pt": ("gpt2", "6c0e608"), "tf": ("gpt2", "6c0e608")}},
"type": "text",
},
"zero-shot-classification": {
@@ -234,9 +257,8 @@
"tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
"pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
"default": {
- "model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
- "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
- "tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
+ "model": {"pt": ("facebook/bart-large-mnli", "c626438"), "tf": ("roberta-large-mnli", "130fb28")},
+ "config": {"pt": ("facebook/bart-large-mnli", "c626438"), "tf": ("roberta-large-mnli", "130fb28")},
},
"type": "text",
},
@@ -244,41 +266,58 @@
"impl": ZeroShotImageClassificationPipeline,
"tf": (TFAutoModel,) if is_tf_available() else (),
"pt": (AutoModel,) if is_torch_available() else (),
- "default": {"model": {"pt": "openai/clip-vit-base-patch32", "tf": "openai/clip-vit-base-patch32"}},
+ "default": {
+ "model": {
+ "pt": ("openai/clip-vit-base-patch32", "f4881ba"),
+ "tf": ("openai/clip-vit-base-patch32", "f4881ba"),
+ }
+ },
"type": "multimodal",
},
"conversational": {
"impl": ConversationalPipeline,
"tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
"pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
- "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
+ "default": {
+ "model": {"pt": ("microsoft/DialoGPT-medium", "8bada3b"), "tf": ("microsoft/DialoGPT-medium", "8bada3b")}
+ },
"type": "text",
},
"image-classification": {
"impl": ImageClassificationPipeline,
- "tf": (),
+ "tf": (TFAutoModelForImageClassification,) if is_tf_available() else (),
"pt": (AutoModelForImageClassification,) if is_torch_available() else (),
- "default": {"model": {"pt": "google/vit-base-patch16-224"}},
+ "default": {
+ "model": {
+ "pt": ("google/vit-base-patch16-224", "5dca96d"),
+ "tf": ("google/vit-base-patch16-224", "5dca96d"),
+ }
+ },
"type": "image",
},
"image-segmentation": {
"impl": ImageSegmentationPipeline,
"tf": (),
"pt": (AutoModelForImageSegmentation, AutoModelForSemanticSegmentation) if is_torch_available() else (),
- "default": {"model": {"pt": "facebook/detr-resnet-50-panoptic"}},
+ "default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "fc15262")}},
"type": "image",
},
"object-detection": {
"impl": ObjectDetectionPipeline,
"tf": (),
"pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
- "default": {"model": {"pt": "facebook/detr-resnet-50"}},
+ "default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},
"type": "image",
},
}
NO_FEATURE_EXTRACTOR_TASKS = set()
NO_TOKENIZER_TASKS = set()
+# Those model configs are special, they are generic over their task, meaning
+# any tokenizer/feature_extractor might be used for a given model so we cannot
+# use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to
+# see if the model defines such objects or not.
+MULTI_MODEL_CONFIGS = {"VisionTextDualEncoderConfig", "SpeechEncoderDecoderConfig"}
for task, values in SUPPORTED_TASKS.items():
if values["type"] == "text":
NO_FEATURE_EXTRACTOR_TASKS.add(task)
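The default entries in SUPPORTED_TASKS above now pin a revision alongside each checkpoint id. A minimal sketch of how such a `(model_id, revision)` tuple can be unpacked per framework; `resolve_default` is a hypothetical helper, not part of this diff:

```python
# Hypothetical helper, not part of this diff: unpack a pinned default such as
# {"model": {"pt": ("gpt2", "6c0e608")}} into a checkpoint id and revision.
def resolve_default(default, framework="pt"):
    model_id, revision = default["model"][framework]
    return model_id, revision


model_id, revision = resolve_default({"model": {"pt": ("gpt2", "6c0e608")}})
print(model_id, revision)  # gpt2 6c0e608
```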
@@ -287,14 +326,14 @@
elif values["type"] != "multimodal":
raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
+PIPELINE_REGISTRY = PipelineRegistry(supported_tasks=SUPPORTED_TASKS, task_aliases=TASK_ALIASES)
+
def get_supported_tasks() -> List[str]:
"""
Returns a list of supported task strings.
"""
- supported_tasks = list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys())
- supported_tasks.sort()
- return supported_tasks
+ return PIPELINE_REGISTRY.get_supported_tasks()
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
@@ -348,37 +387,44 @@ def check_task(task: str) -> Tuple[Dict, Any]:
- `"zero-shot-image-classification"`
Returns:
- (task_defaults`dict`, task_options: (`tuple`, None)) The actual dictionary required to initialize the pipeline
- and some extra task options for parametrized tasks like "translation_XX_to_YY"
+ (normalized_task: `str`, task_defaults: `dict`, task_options: (`tuple`, None)) The normalized task name
+ (removed alias and options). The actual dictionary required to initialize the pipeline and some extra task
+ options for parametrized tasks like "translation_XX_to_YY"
"""
- if task in TASK_ALIASES:
- task = TASK_ALIASES[task]
- if task in SUPPORTED_TASKS:
- targeted_task = SUPPORTED_TASKS[task]
- return targeted_task, None
+ return PIPELINE_REGISTRY.check_task(task)
+
- if task.startswith("translation"):
- tokens = task.split("_")
- if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
- targeted_task = SUPPORTED_TASKS["translation"]
- return targeted_task, (tokens[1], tokens[3])
- raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")
+def clean_custom_task(task_info):
+ import transformers
- raise KeyError(f"Unknown task {task}, available tasks are {get_supported_tasks() + ['translation_XX_to_YY']}")
+ if "impl" not in task_info:
+ raise RuntimeError("This model introduces a custom pipeline without specifying its implementation.")
+ pt_class_names = task_info.get("pt", ())
+ if isinstance(pt_class_names, str):
+ pt_class_names = [pt_class_names]
+ task_info["pt"] = tuple(getattr(transformers, c) for c in pt_class_names)
+ tf_class_names = task_info.get("tf", ())
+ if isinstance(tf_class_names, str):
+ tf_class_names = [tf_class_names]
+ task_info["tf"] = tuple(getattr(transformers, c) for c in tf_class_names)
+ return task_info, None
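As a rough illustration of what `clean_custom_task` does, the sketch below resolves string class names against a stand-in namespace instead of the real `transformers` module; the task info and class names are illustrative:

```python
# Stand-in namespace so the sketch stays self-contained; in the real code the
# lookup is `getattr(transformers, class_name)`.
class FakeTransformers:
    class AutoModelForSequenceClassification:
        pass


def resolve_classes(task_info, namespace):
    names = task_info.get("pt", ())
    if isinstance(names, str):
        names = [names]
    # Turn string class names from the Hub config into actual classes.
    task_info["pt"] = tuple(getattr(namespace, name) for name in names)
    return task_info


info = resolve_classes(
    {"impl": "my_pipeline.MyPipeline", "pt": "AutoModelForSequenceClassification"},
    FakeTransformers,
)
print(info["pt"][0].__name__)  # AutoModelForSequenceClassification
```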
def pipeline(
task: str = None,
model: Optional = None,
config: Optional[Union[str, PretrainedConfig]] = None,
- tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
+ tokenizer: Optional[Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast]] = None,
feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
framework: Optional[str] = None,
revision: Optional[str] = None,
use_fast: bool = True,
use_auth_token: Optional[Union[str, bool]] = None,
+ device_map=None,
+ torch_dtype=None,
+ trust_remote_code: Optional[bool] = None,
model_kwargs: Dict[str, Any] = None,
pipeline_class: Optional[Any] = None,
**kwargs
@@ -461,7 +507,25 @@ def pipeline(
Whether or not to use a Fast tokenizer if possible (a [`PreTrainedTokenizerFast`]).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
- when running `transformers-cli login` (stored in `~/.huggingface`).
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]`, *optional*):
+ Sent directly as `model_kwargs` (just a simpler shortcut). When `accelerate` library is present, set
+ `device_map="auto"` to compute the most optimized `device_map` automatically. [More
+ information](https://huggingface.co/docs/accelerate/main/en/big_modeling#accelerate.cpu_offload)
+
+ <Tip warning={true}>
+
+ Do not use `device_map` AND `device` at the same time as they will conflict
+
+ </Tip>
+
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
+ (`torch.float16`, `torch.bfloat16`, ... or `"auto"`).
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether or not to allow for custom code defined on the Hub in their own modeling, configuration,
+ tokenization or even pipeline files. This option should only be set to `True` for repositories you trust
+ and in which you have read the code, as it will execute code present on the Hub on your local machine.
model_kwargs:
Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
**model_kwargs)` function.
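A hedged usage sketch of the new shortcuts documented above (`device_map="auto"` requires the `accelerate` library; the checkpoint is illustrative):

```python
import torch

from transformers import pipeline

# Both arguments are forwarded into model_kwargs by the factory.
generator = pipeline(
    "text-generation",
    model="gpt2",
    device_map="auto",
    torch_dtype=torch.float16,
)
print(generator("Hello,", max_new_tokens=5)[0]["generated_text"])
```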
@@ -490,6 +554,10 @@ def pipeline(
```"""
if model_kwargs is None:
model_kwargs = {}
+ # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
+ # this is to keep BC).
+ use_auth_token = model_kwargs.pop("use_auth_token", use_auth_token)
+ hub_kwargs = {"revision": revision, "use_auth_token": use_auth_token, "trust_remote_code": trust_remote_code}
if task is None and model is None:
raise RuntimeError(
@@ -500,17 +568,36 @@ def pipeline(
if model is None and tokenizer is not None:
raise RuntimeError(
- "Impossible to instantiate a pipeline with tokenizer specified but not the model "
- "as the provided tokenizer may not be compatible with the default model. "
- "Please provide a PreTrainedModel class or a path/identifier to a pretrained model when providing tokenizer."
+ "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided tokenizer"
+ " may not be compatible with the default model. Please provide a PreTrainedModel class or a"
+ " path/identifier to a pretrained model when providing tokenizer."
)
if model is None and feature_extractor is not None:
raise RuntimeError(
- "Impossible to instantiate a pipeline with feature_extractor specified but not the model "
- "as the provided feature_extractor may not be compatible with the default model. "
- "Please provide a PreTrainedModel class or a path/identifier to a pretrained model when providing feature_extractor."
+ "Impossible to instantiate a pipeline with feature_extractor specified but not the model as the provided"
+ " feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class"
+ " or a path/identifier to a pretrained model when providing feature_extractor."
)
+ # Config is the primordial information item.
+ # Instantiate config if needed
+ if isinstance(config, str):
+ config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+ elif config is None and isinstance(model, str):
+ config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+
+ custom_tasks = {}
+ if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
+ custom_tasks = config.custom_pipelines
+ if task is None and trust_remote_code is not False:
+ if len(custom_tasks) == 1:
+ task = list(custom_tasks.keys())[0]
+ else:
+ raise RuntimeError(
+ "We can't infer the task automatically for this model as there are multiple tasks available. Pick "
+ f"one in {', '.join(custom_tasks.keys())}"
+ )
+
if task is None and model is not None:
if not isinstance(model, str):
raise RuntimeError(
@@ -520,25 +607,53 @@ def pipeline(
task = get_task(model, use_auth_token)
# Retrieve the task
- targeted_task, task_options = check_task(task)
- if pipeline_class is None:
- pipeline_class = targeted_task["impl"]
+ if task in custom_tasks:
+ normalized_task = task
+ targeted_task, task_options = clean_custom_task(custom_tasks[task])
+ if pipeline_class is None:
+ if not trust_remote_code:
+ raise ValueError(
+ "Loading this pipeline requires you to execute the code in the pipeline file in that"
+ " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
+ " set the option `trust_remote_code=True` to remove this error."
+ )
+ class_ref = targeted_task["impl"]
+ module_file, class_name = class_ref.split(".")
+ pipeline_class = get_class_from_dynamic_module(
+ model, module_file + ".py", class_name, revision=revision, use_auth_token=use_auth_token
+ )
+ else:
+ normalized_task, targeted_task, task_options = check_task(task)
+ if pipeline_class is None:
+ pipeline_class = targeted_task["impl"]
# Use default model/config/tokenizer for the task if no model is provided
if model is None:
# At that point framework might still be undetermined
- model = get_default_model(targeted_task, framework, task_options)
- logger.warning(f"No model was supplied, defaulted to {model} (https://huggingface.co/{model})")
-
- # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
- model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
-
- # Config is the primordial information item.
- # Instantiate config if needed
- if isinstance(config, str):
- config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
- elif config is None and isinstance(model, str):
- config = AutoConfig.from_pretrained(model, revision=revision, _from_pipeline=task, **model_kwargs)
+ model, default_revision = get_default_model_and_revision(targeted_task, framework, task_options)
+ revision = revision if revision is not None else default_revision
+ logger.warning(
+ f"No model was supplied, defaulted to {model} and revision"
+ f" {revision} ({HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
+ "Using a pipeline without specifying a model name and revision in production is not recommended."
+ )
+ if config is None and isinstance(model, str):
+ config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+
+ if device_map is not None:
+ if "device_map" in model_kwargs:
+ raise ValueError(
+ 'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those'
+ " arguments might conflict, use only one.)"
+ )
+ model_kwargs["device_map"] = device_map
+ if torch_dtype is not None:
+ if "torch_dtype" in model_kwargs:
+ raise ValueError(
+ 'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
+ " arguments might conflict, use only one.)"
+ )
+ model_kwargs["torch_dtype"] = torch_dtype
model_name = model if isinstance(model, str) else None
@@ -551,8 +666,8 @@ def pipeline(
model_classes=model_classes,
config=config,
framework=framework,
- revision=revision,
task=task,
+ **hub_kwargs,
**model_kwargs,
)
@@ -561,12 +676,36 @@ def pipeline(
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
+ if (
+ tokenizer is None
+ and not load_tokenizer
+ and normalized_task not in NO_TOKENIZER_TASKS
+ # Using class name to avoid importing the real class.
+ and model_config.__class__.__name__ in MULTI_MODEL_CONFIGS
+ ):
+ # This is a special category of models, that are fusions of multiple models
+ # so the model_config might not define a tokenizer, but it seems to be
+ # necessary for the task, so we're force-trying to load it.
+ load_tokenizer = True
+ if (
+ feature_extractor is None
+ and not load_feature_extractor
+ and normalized_task not in NO_FEATURE_EXTRACTOR_TASKS
+ # Using class name to avoid importing the real class.
+ and model_config.__class__.__name__ in MULTI_MODEL_CONFIGS
+ ):
+ # This is a special category of models, that are fusions of multiple models
+ # so the model_config might not define a feature extractor, but it seems to be
+ # necessary for the task, so we're force-trying to load it.
+ load_feature_extractor = True
+
if task in NO_TOKENIZER_TASKS:
# These will never require a tokenizer.
# the model on the other hand might have a tokenizer, but
# the files could be missing from the hub, instead of failing
# on such repos, we just force to not load it.
load_tokenizer = False
+
if task in NO_FEATURE_EXTRACTOR_TASKS:
load_feature_extractor = False
@@ -596,7 +735,7 @@ def pipeline(
tokenizer_kwargs = model_kwargs
tokenizer = AutoTokenizer.from_pretrained(
- tokenizer_identifier, revision=revision, use_fast=use_fast, _from_pipeline=task, **tokenizer_kwargs
+ tokenizer_identifier, use_fast=use_fast, _from_pipeline=task, **hub_kwargs, **tokenizer_kwargs
)
if load_feature_extractor:
@@ -617,7 +756,7 @@ def pipeline(
# Instantiate feature_extractor if needed
if isinstance(feature_extractor, (str, tuple)):
feature_extractor = AutoFeatureExtractor.from_pretrained(
- feature_extractor, revision=revision, _from_pipeline=task, **model_kwargs
+ feature_extractor, _from_pipeline=task, **hub_kwargs, **model_kwargs
)
if (
@@ -642,7 +781,9 @@ def pipeline(
kwargs["decoder"] = decoder
except ImportError as e:
logger.warning(
- f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
+ f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install"
+ " `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install"
+ f" https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
)
if task == "translation" and model.config.task_specific_params:
diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py
index e10dc1208e45..c52b1002cf71 100644
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -168,7 +168,7 @@ def __call__(
pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return
`timestamps` along the text for every word in the text. For instance if you get `[{"text": "hi ",
"timestamps": (0.5,0.9), {"text": "there", "timestamps": (1.0, .1.5)}]`, then it means the model
- predicts that the word "hi" was pronounces before 0.5 and after 0.9 seconds.
+ predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
Return:
`Dict`: A dictionary with the following keys:
diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index d54a17df1e9d..6e2c28e5ddf8 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -29,6 +29,7 @@
from packaging import version
+from ..dynamic_module_utils import custom_object_save
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..modelcard import ModelCard
from ..models.auto.configuration_auto import AutoConfig
@@ -75,14 +76,19 @@ def _pad(items, key, padding_value, padding_side):
# Others include `attention_mask` etc...
shape = items[0][key].shape
dim = len(shape)
- if dim == 4:
+ if key == "pixel_values":
# This is probable image so padding shouldn't be necessary
# B, C, H, W
return torch.cat([item[key] for item in items], dim=0)
max_length = max(item[key].shape[1] for item in items)
+ min_length = min(item[key].shape[1] for item in items)
dtype = items[0][key].dtype
if dim == 2:
+ if max_length == min_length:
+ # Bypass for `ImageGPT` which doesn't provide a padding value, yet
+ # we can consistently pad since the size should be matching
+ return torch.cat([item[key] for item in items], dim=0)
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
elif dim == 3:
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
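A small self-contained check of the new bypass: when every dim-2 tensor already has the same length, the batch is built with a plain concatenation instead of a padded buffer (tensors below are illustrative):

```python
import torch

items = [{"input_ids": torch.tensor([[1, 2, 3]])}, {"input_ids": torch.tensor([[4, 5, 6]])}]
max_length = max(item["input_ids"].shape[1] for item in items)
min_length = min(item["input_ids"].shape[1] for item in items)
if max_length == min_length:
    # Same path as the ImageGPT bypass above: no padding value is needed.
    batch = torch.cat([item["input_ids"] for item in items], dim=0)
print(batch.shape)  # torch.Size([2, 3])
```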
@@ -139,13 +145,18 @@ def inner(items):
for item in items:
if set(item.keys()) != keys:
raise ValueError(
- f"The elements of the batch contain different keys. Cannot batch them ({set(item.keys())} != {keys})"
+ f"The elements of the batch contain different keys. Cannot batch them ({set(item.keys())} !="
+ f" {keys})"
)
# input_values, input_pixels, input_ids, ...
padded = {}
for key in keys:
if key in {"input_ids"}:
- _padding_value = t_padding_value
+ # ImageGPT uses a feature extractor
+ if feature_extractor is not None:
+ _padding_value = f_padding_value
+ else:
+ _padding_value = t_padding_value
elif key in {"input_values", "pixel_values", "input_features"}:
_padding_value = f_padding_value
elif key in {"p_mask", "special_tokens_mask"}:
@@ -331,7 +342,9 @@ def get_framework(model, revision: Optional[str] = None):
return framework
-def get_default_model(targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]) -> str:
+def get_default_model_and_revision(
+ targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]
+) -> Union[str, Tuple[str, str]]:
"""
Select a default model to use for a given task. Defaults to pytorch if ambiguous.
@@ -617,7 +630,6 @@ def __iter__(self):
for line in sys.stdin:
# Split for multi-columns
if "\t" in line:
-
line = line.split("\t")
if self.column:
# Dictionary to map arguments
@@ -692,7 +704,7 @@ def predict(self, X):
Reference to the object in charge of parsing supplied pipeline parameters.
device (`int`, *optional*, defaults to -1):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
- the associated CUDA device id.
+ the associated CUDA device id. You can pass native `torch.device` too.
binary_output (`bool`, *optional*, defaults to `False`):
Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text.
"""
@@ -739,7 +751,6 @@ def __init__(
binary_output: bool = False,
**kwargs,
):
-
if framework is None:
framework, model = infer_framework_load_model(model, config=model.config)
@@ -749,7 +760,10 @@ def __init__(
self.feature_extractor = feature_extractor
self.modelcard = modelcard
self.framework = framework
- self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
+ if is_torch_available() and isinstance(device, torch.device):
+ self.device = device
+ else:
+ self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
self.binary_output = binary_output
# Special handling
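A sketch of the device handling added above, isolated from the class for clarity (the helper name is illustrative):

```python
import torch


def normalize_device(device, framework="pt"):
    # torch.device instances are now accepted as-is; integers keep the old mapping.
    if isinstance(device, torch.device):
        return device
    return device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")


print(normalize_device(torch.device("cpu")))  # cpu
print(normalize_device(-1))                   # cpu
print(normalize_device(0))                    # cuda:0
```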
@@ -779,6 +793,27 @@ def save_pretrained(self, save_directory: str):
return
os.makedirs(save_directory, exist_ok=True)
+ if hasattr(self, "_registered_impl"):
+ # Add info to the config
+ pipeline_info = self._registered_impl.copy()
+ custom_pipelines = {}
+ for task, info in pipeline_info.items():
+ if info["impl"] != self.__class__:
+ continue
+
+ info = info.copy()
+ module_name = info["impl"].__module__
+ last_module = module_name.split(".")[-1]
+ # Change classes into their names/full names
+ info["impl"] = f"{last_module}.{info['impl'].__name__}"
+ info["pt"] = tuple(c.__name__ for c in info["pt"])
+ info["tf"] = tuple(c.__name__ for c in info["tf"])
+
+ custom_pipelines[task] = info
+ self.model.config.custom_pipelines = custom_pipelines
+ # Save the pipeline custom code
+ custom_object_save(self, save_directory)
+
self.model.save_pretrained(save_directory)
if self.tokenizer is not None:
@@ -856,6 +891,8 @@ def _ensure_tensor_on_device(self, inputs, device):
elif isinstance(inputs, tuple):
return tuple([self._ensure_tensor_on_device(item, device) for item in inputs])
elif isinstance(inputs, torch.Tensor):
+ if device == torch.device("cpu") and inputs.dtype in {torch.float16, torch.bfloat16}:
+ inputs = inputs.float()
return inputs.to(device)
else:
return inputs
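The CPU upcast added above, as a standalone sketch (function name illustrative):

```python
import torch


def ensure_on_device(tensor, device):
    # Half-precision tensors are upcast to float32 when moved to CPU, where many
    # ops have no fp16/bf16 kernels.
    if device == torch.device("cpu") and tensor.dtype in {torch.float16, torch.bfloat16}:
        tensor = tensor.float()
    return tensor.to(device)


print(ensure_on_device(torch.ones(2, dtype=torch.float16), torch.device("cpu")).dtype)  # torch.float32
```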
@@ -879,7 +916,8 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
supported_models = supported_models_names
if self.model.__class__.__name__ not in supported_models:
logger.error(
- f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}."
+ f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are"
+ f" {supported_models}."
)
@abstractmethod
@@ -927,7 +965,9 @@ def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: Dict
def get_inference_context(self):
inference_context = (
- torch.inference_mode if version.parse(torch.__version__) >= version.parse("1.9.0") else torch.no_grad
+ torch.inference_mode
+ if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.9.0")
+ else torch.no_grad
)
return inference_context
@@ -994,7 +1034,8 @@ def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs):
self.call_count += 1
if self.call_count > 10 and self.framework == "pt" and self.device.type == "cuda":
warnings.warn(
- "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset",
+ "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a"
+ " dataset",
UserWarning,
)
@@ -1058,7 +1099,8 @@ def get_iterator(
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if num_workers > 1:
logger.warning(
- "For ChunkPipeline using num_workers>0 is likely to result in errors since everything is iterable, setting `num_workers=1` to guarantee correctness."
+ "For ChunkPipeline using num_workers>0 is likely to result in errors since everything is iterable,"
+ " setting `num_workers=1` to guarantee correctness."
)
num_workers = 1
dataset = PipelineChunkIterator(inputs, self.preprocess, preprocess_params)
@@ -1067,3 +1109,71 @@ def get_iterator(
model_iterator = PipelinePackIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
return final_iterator
+
+
+class PipelineRegistry:
+ def __init__(self, supported_tasks: Dict[str, Any], task_aliases: Dict[str, str]) -> None:
+ self.supported_tasks = supported_tasks
+ self.task_aliases = task_aliases
+
+ def get_supported_tasks(self) -> List[str]:
+ supported_task = list(self.supported_tasks.keys()) + list(self.task_aliases.keys())
+ supported_task.sort()
+ return supported_task
+
+ def check_task(self, task: str) -> Tuple[str, Dict, Any]:
+ if task in self.task_aliases:
+ task = self.task_aliases[task]
+ if task in self.supported_tasks:
+ targeted_task = self.supported_tasks[task]
+ return task, targeted_task, None
+
+ if task.startswith("translation"):
+ tokens = task.split("_")
+ if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
+ targeted_task = self.supported_tasks["translation"]
+ task = "translation"
+ return task, targeted_task, (tokens[1], tokens[3])
+ raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")
+
+ raise KeyError(
+ f"Unknown task {task}, available tasks are {self.get_supported_tasks() + ['translation_XX_to_YY']}"
+ )
+
+ def register_pipeline(
+ self,
+ task: str,
+ pipeline_class: type,
+ pt_model: Optional[Union[type, Tuple[type]]] = None,
+ tf_model: Optional[Union[type, Tuple[type]]] = None,
+ default: Optional[Dict] = None,
+ type: Optional[str] = None,
+ ) -> None:
+ if task in self.supported_tasks:
+ logger.warning(f"{task} is already registered. Overwriting pipeline for task {task}...")
+
+ if pt_model is None:
+ pt_model = ()
+ elif not isinstance(pt_model, tuple):
+ pt_model = (pt_model,)
+
+ if tf_model is None:
+ tf_model = ()
+ elif not isinstance(tf_model, tuple):
+ tf_model = (tf_model,)
+
+ task_impl = {"impl": pipeline_class, "pt": pt_model, "tf": tf_model}
+
+ if default is not None:
+ if "model" not in default and ("pt" in default or "tf" in default):
+ default = {"model": default}
+ task_impl["default"] = default
+
+ if type is not None:
+ task_impl["type"] = type
+
+ self.supported_tasks[task] = task_impl
+ pipeline_class._registered_impl = {task: task_impl}
+
+ def to_dict(self):
+ return self.supported_tasks
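A hedged sketch of how a custom task could be registered through the new registry; the task name and default checkpoint/revision are illustrative, and the registered class reuses the existing text-classification pipeline:

```python
from transformers import AutoModelForSequenceClassification
from transformers.pipelines import PIPELINE_REGISTRY, TextClassificationPipeline

PIPELINE_REGISTRY.register_pipeline(
    "my-text-classification",  # hypothetical task name
    pipeline_class=TextClassificationPipeline,
    pt_model=AutoModelForSequenceClassification,
    default={"pt": ("distilbert-base-uncased-finetuned-sst-2-english", "main")},
    type="text",
)
print("my-text-classification" in PIPELINE_REGISTRY.get_supported_tasks())  # True
```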
diff --git a/src/transformers/pipelines/fill_mask.py b/src/transformers/pipelines/fill_mask.py
index 517b457a654b..f461f6faa2af 100644
--- a/src/transformers/pipelines/fill_mask.py
+++ b/src/transformers/pipelines/fill_mask.py
@@ -167,7 +167,7 @@ def get_target_ids(self, targets, top_k=None):
if len(input_ids) == 0:
logger.warning(
f"The specified target token `{target}` does not exist in the model vocabulary. "
- f"We cannot replace it with anything meaningful, ignoring it"
+ "We cannot replace it with anything meaningful, ignoring it"
)
continue
id_ = input_ids[0]
diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py
index c629f703a030..6f07382dc57c 100644
--- a/src/transformers/pipelines/question_answering.py
+++ b/src/transformers/pipelines/question_answering.py
@@ -1,3 +1,4 @@
+import types
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
@@ -7,7 +8,14 @@
from ..data import SquadExample, SquadFeatures, squad_convert_examples_to_features
from ..modelcard import ModelCard
from ..tokenization_utils import PreTrainedTokenizer
-from ..utils import PaddingStrategy, add_end_docstrings, is_tf_available, is_torch_available, logging
+from ..utils import (
+ PaddingStrategy,
+ add_end_docstrings,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+ logging,
+)
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline
@@ -17,13 +25,19 @@
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel
+ if is_tokenizers_available():
+ import tokenizers
+
if is_tf_available():
import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
+ Dataset = None
+
if is_torch_available():
import torch
+ from torch.utils.data import Dataset
from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
@@ -82,6 +96,11 @@ def __call__(self, *args, **kwargs):
else:
raise ValueError(f"Unknown arguments {kwargs}")
+ # When user is sending a generator we need to trust it's a valid example
+ generator_types = (types.GeneratorType, Dataset) if Dataset is not None else (types.GeneratorType,)
+ if isinstance(inputs, generator_types):
+ return inputs
+
# Normalize inputs
if isinstance(inputs, dict):
inputs = [inputs]
@@ -171,6 +190,7 @@ def _sanitize_parameters(
max_seq_len=None,
max_question_len=None,
handle_impossible_answer=None,
+ align_to_words=None,
**kwargs
):
# Set defaults values
@@ -199,6 +219,8 @@ def _sanitize_parameters(
postprocess_params["max_answer_len"] = max_answer_len
if handle_impossible_answer is not None:
postprocess_params["handle_impossible_answer"] = handle_impossible_answer
+ if align_to_words is not None:
+ postprocess_params["align_to_words"] = align_to_words
return preprocess_params, {}, postprocess_params
def __call__(self, *args, **kwargs):
@@ -228,12 +250,15 @@ def __call__(self, *args, **kwargs):
max_answer_len (`int`, *optional*, defaults to 15):
The maximum length of predicted answers (e.g., only answers with a shorter length are considered).
max_seq_len (`int`, *optional*, defaults to 384):
- The maximum length of the total sentence (context + question) after tokenization. The context will be
- split in several chunks (using `doc_stride`) if needed.
+ The maximum length of the total sentence (context + question) in tokens of each chunk passed to the
+ model. The context will be split in several chunks (using `doc_stride` as overlap) if needed.
max_question_len (`int`, *optional*, defaults to 64):
The maximum length of the question after tokenization. It will be truncated if needed.
handle_impossible_answer (`bool`, *optional*, defaults to `False`):
Whether or not we accept impossible as an answer.
+ align_to_words (`bool`, *optional*, defaults to `True`):
+ Attempts to align the answer to real words. Improves quality on space-separated languages. Might hurt on
+ non-space-separated languages (like Japanese or Chinese).
Return:
A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:
@@ -245,12 +270,18 @@ def __call__(self, *args, **kwargs):
"""
# Convert inputs to features
+
examples = self._args_parser(*args, **kwargs)
- if len(examples) == 1:
+ if isinstance(examples, (list, tuple)) and len(examples) == 1:
return super().__call__(examples[0], **kwargs)
return super().__call__(examples, **kwargs)
def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_question_len=64, max_seq_len=None):
+ # XXX: This is special, args_parser will not handle anything generator or dataset like.
+ # For those we expect the user to send a simple valid example either directly as a SquadExample or a simple dict.
+ # So we still need a little sanitization here.
+ if isinstance(example, dict):
+ example = SquadExample(None, example["question"], example["context"], None, None, None)
if max_seq_len is None:
max_seq_len = min(self.tokenizer.model_max_length, 384)
@@ -279,7 +310,6 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_questio
truncation="only_second" if question_first else "only_first",
max_length=max_seq_len,
stride=doc_stride,
- return_tensors="np",
return_token_type_ids=True,
return_overflowing_tokens=True,
return_offsets_mapping=True,
@@ -294,12 +324,10 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_questio
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
# We put 0 on the tokens from the context and 1 everywhere else (question and special tokens)
- p_mask = np.asarray(
- [
- [tok != 1 if question_first else 0 for tok in encoded_inputs.sequence_ids(span_id)]
- for span_id in range(num_spans)
- ]
- )
+ p_mask = [
+ [tok != 1 if question_first else 0 for tok in encoded_inputs.sequence_ids(span_id)]
+ for span_id in range(num_spans)
+ ]
features = []
for span_idx in range(num_spans):
@@ -316,8 +344,6 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_questio
for cls_index in cls_indices:
p_mask[span_idx][cls_index] = 0
submask = p_mask[span_idx]
- if isinstance(submask, np.ndarray):
- submask = submask.tolist()
features.append(
SquadFeatures(
input_ids=input_ids_span_idx,
@@ -344,7 +370,7 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_questio
for i, feature in enumerate(features):
fw_args = {}
others = {}
- model_input_names = self.tokenizer.model_input_names + ["p_mask"]
+ model_input_names = self.tokenizer.model_input_names + ["p_mask", "token_type_ids"]
for k, v in feature.__dict__.items():
if k in model_input_names:
@@ -376,6 +402,7 @@ def postprocess(
top_k=1,
handle_impossible_answer=False,
max_answer_len=15,
+ align_to_words=True,
):
min_null_score = 1000000 # large and positive
answers = []
@@ -398,8 +425,11 @@ def postprocess(
end_ = np.where(undesired_tokens_mask, -10000.0, end_)
# Normalize logits and spans to retrieve the answer
- start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True)))
- end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))
+ start_ = np.exp(start_ - start_.max(axis=-1, keepdims=True))
+ start_ = start_ / start_.sum()
+
+ end_ = np.exp(end_ - end_.max(axis=-1, keepdims=True))
+ end_ = end_ / end_.sum()
if handle_impossible_answer:
min_null_score = min(min_null_score, (start_[0, 0] * end_[0, 0]).item())
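The normalization above now subtracts the row maximum before exponentiating, which is the standard numerically stable softmax. A quick numpy check:

```python
import numpy as np

start_ = np.array([[1000.0, 999.0, 998.0]])  # large logits would overflow a plain exp
shifted = np.exp(start_ - start_.max(axis=-1, keepdims=True))
probs = shifted / shifted.sum()
print(probs)  # approximately [[0.665 0.245 0.090]]
```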
@@ -451,15 +481,8 @@ def postprocess(
for s, e, score in zip(starts, ends, scores):
s = s - offset
e = e - offset
- try:
- start_word = enc.token_to_word(s)
- end_word = enc.token_to_word(e)
- start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
- end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
- except Exception:
- # Some tokenizers don't really handle words. Keep to offsets then.
- start_index = enc.offsets[s][0]
- end_index = enc.offsets[e][1]
+
+ start_index, end_index = self.get_indices(enc, s, e, sequence_index, align_to_words)
answers.append(
{
@@ -477,6 +500,24 @@ def postprocess(
return answers[0]
return answers
+ def get_indices(
+ self, enc: "tokenizers.Encoding", s: int, e: int, sequence_index: int, align_to_words: bool
+ ) -> Tuple[int, int]:
+ if align_to_words:
+ try:
+ start_word = enc.token_to_word(s)
+ end_word = enc.token_to_word(e)
+ start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
+ end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
+ except Exception:
+ # Some tokenizers don't really handle words. Keep to offsets then.
+ start_index = enc.offsets[s][0]
+ end_index = enc.offsets[e][1]
+ else:
+ start_index = enc.offsets[s][0]
+ end_index = enc.offsets[e][1]
+ return start_index, end_index
+
def decode(
self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int, undesired_tokens: np.ndarray
) -> Tuple:
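A hedged usage sketch of the new `align_to_words` switch (checkpoint and inputs are illustrative):

```python
from transformers import pipeline

qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
out = qa(
    question="Where do I live?",
    context="My name is Wolfgang and I live in Berlin.",
    align_to_words=False,  # fall back to raw token offsets
)
print(out["answer"])
```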
diff --git a/src/transformers/pipelines/table_question_answering.py b/src/transformers/pipelines/table_question_answering.py
index d94bb6d061ff..25dcd320cf4f 100644
--- a/src/transformers/pipelines/table_question_answering.py
+++ b/src/transformers/pipelines/table_question_answering.py
@@ -56,14 +56,14 @@ def __call__(self, table=None, query=None, **kwargs):
tqa_pipeline_inputs = table
else:
raise ValueError(
- f"If keyword argument `table` is a list of dictionaries, each dictionary should have a `table` "
- f"and `query` key, but only dictionary has keys {table[0].keys()} `table` and `query` keys."
+ "If keyword argument `table` is a list of dictionaries, each dictionary should have a `table`"
+ f" and `query` key, but only dictionary has keys {table[0].keys()} `table` and `query` keys."
)
elif Dataset is not None and isinstance(table, Dataset) or isinstance(table, types.GeneratorType):
return table
else:
raise ValueError(
- f"Invalid input. Keyword argument `table` should be either of type `dict` or `list`, but "
+ "Invalid input. Keyword argument `table` should be either of type `dict` or `list`, but "
f"is {type(table)})"
)
else:
diff --git a/src/transformers/pipelines/text_classification.py b/src/transformers/pipelines/text_classification.py
index 3d3f4e533d45..dd8de4c7357f 100644
--- a/src/transformers/pipelines/text_classification.py
+++ b/src/transformers/pipelines/text_classification.py
@@ -1,3 +1,4 @@
+import warnings
from typing import Dict
import numpy as np
@@ -72,15 +73,26 @@ def __init__(self, **kwargs):
else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
)
- def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, **tokenizer_kwargs):
+ def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, top_k="", **tokenizer_kwargs):
+ # Using "" as default argument because we're going to use `top_k=None` in user code to declare
+ # "No top_k"
preprocess_params = tokenizer_kwargs
postprocess_params = {}
if hasattr(self.model.config, "return_all_scores") and return_all_scores is None:
return_all_scores = self.model.config.return_all_scores
- if return_all_scores is not None:
- postprocess_params["return_all_scores"] = return_all_scores
+ if isinstance(top_k, int) or top_k is None:
+ postprocess_params["top_k"] = top_k
+ postprocess_params["_legacy"] = False
+ elif return_all_scores is not None:
+ warnings.warn(
+ "`return_all_scores` is now deprecated, use `top_k=1` if you want similar functionnality", UserWarning
+ )
+ if return_all_scores:
+ postprocess_params["top_k"] = None
+ else:
+ postprocess_params["top_k"] = 1
if isinstance(function_to_apply, str):
function_to_apply = ClassificationFunction[function_to_apply.upper()]
@@ -94,10 +106,11 @@ def __call__(self, *args, **kwargs):
Classify the text(s) given as inputs.
Args:
- args (`str` or `List[str]`):
- One or several texts (or one list of prompts) to classify.
- return_all_scores (`bool`, *optional*, defaults to `False`):
- Whether to return scores for all labels.
+ args (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`):
+ One or several texts to classify. In order to use text pairs for your classification, you can send a
+ dictionary containing `{"text", "text_pair"}` keys, or a list of those.
+ top_k (`int`, *optional*, defaults to `1`):
+ How many results to return.
function_to_apply (`str`, *optional*, defaults to `"default"`):
The function to apply to the model outputs in order to retrieve the scores. Accepts four different
values:
@@ -120,10 +133,12 @@ def __call__(self, *args, **kwargs):
- **label** (`str`) -- The label predicted.
- **score** (`float`) -- The corresponding probability.
- If `self.return_all_scores=True`, one such dictionary is returned per label.
+ If `top_k` is used, one such dictionary is returned per label.
"""
result = super().__call__(*args, **kwargs)
- if isinstance(args[0], str):
+ # TODO try and retrieve it in a nicer way from _sanitize_parameters.
+ _legacy = "top_k" not in kwargs
+ if isinstance(args[0], str) and _legacy:
# This pipeline is odd, and return a list when single item is run
return [result]
else:
@@ -131,12 +146,28 @@ def __call__(self, *args, **kwargs):
def preprocess(self, inputs, **tokenizer_kwargs) -> Dict[str, GenericTensor]:
return_tensors = self.framework
+ if isinstance(inputs, dict):
+ return self.tokenizer(**inputs, return_tensors=return_tensors, **tokenizer_kwargs)
+ elif isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], list) and len(inputs[0]) == 2:
+ # It used to be valid to use a list of list of list for text pairs, keeping this path for BC
+ return self.tokenizer(
+ text=inputs[0][0], text_pair=inputs[0][1], return_tensors=return_tensors, **tokenizer_kwargs
+ )
+ elif isinstance(inputs, list):
+ # This is likely an invalid usage of the pipeline attempting to pass text pairs.
+ raise ValueError(
+ "The pipeline received invalid inputs, if you are trying to send text pairs, you can try to send a"
+ ' dictionary `{"text": "My text", "text_pair": "My pair"}` in order to send a text pair.'
+ )
return self.tokenizer(inputs, return_tensors=return_tensors, **tokenizer_kwargs)
def _forward(self, model_inputs):
return self.model(**model_inputs)
- def postprocess(self, model_outputs, function_to_apply=None, return_all_scores=False):
+ def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=True):
+ # `_legacy` is used to determine if we're running the naked pipeline and in backward
+ # compatibility mode, or if running the pipeline with `pipeline(..., top_k=1)` we're running
+ # the more natural result containing the list.
# Default value before `set_parameters`
if function_to_apply is None:
if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1:
@@ -160,7 +191,14 @@ def postprocess(self, model_outputs, function_to_apply=None, return_all_scores=F
else:
raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}")
- if return_all_scores:
- return [{"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores)]
- else:
+ if top_k == 1 and _legacy:
return {"label": self.model.config.id2label[scores.argmax().item()], "score": scores.max().item()}
+
+ dict_scores = [
+ {"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores)
+ ]
+ if not _legacy:
+ dict_scores.sort(key=lambda x: x["score"], reverse=True)
+ if top_k is not None:
+ dict_scores = dict_scores[:top_k]
+ return dict_scores
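A hedged usage sketch of the `top_k` replacement for `return_all_scores` (checkpoint illustrative):

```python
from transformers import pipeline

classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
print(classifier("I love this movie"))              # legacy behaviour: best label only
print(classifier("I love this movie", top_k=None))  # new behaviour: all labels, sorted by score
```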
diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index dbaa0a9df75a..7d15316492b9 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -103,7 +103,8 @@ def _sanitize_parameters(
if handle_long_generation is not None:
if handle_long_generation not in {"hole"}:
raise ValueError(
- f"{handle_long_generation} is not a valid value for `handle_long_generation` parameter expected [None, 'hole']"
+ f"{handle_long_generation} is not a valid value for `handle_long_generation` parameter expected"
+ " [None, 'hole']"
)
preprocess_params["handle_long_generation"] = handle_long_generation
@@ -192,7 +193,8 @@ def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **gene
keep_length = self.tokenizer.model_max_length - new_tokens
if keep_length <= 0:
raise ValueError(
- "We cannot use `hole` to handle this generation the number of desired tokens exceeds the models max length"
+ "We cannot use `hole` to handle this generation the number of desired tokens exceeds the"
+ " models max length"
)
inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
@@ -203,14 +205,17 @@ def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **gene
def _forward(self, model_inputs, **generate_kwargs):
input_ids = model_inputs["input_ids"]
+ attention_mask = model_inputs.get("attention_mask", None)
# Allow empty prompts
if input_ids.shape[1] == 0:
input_ids = None
+ attention_mask = None
in_b = 1
else:
in_b = input_ids.shape[0]
prompt_text = model_inputs.pop("prompt_text")
- generated_sequence = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL
+ # BS x SL
+ generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
out_b = generated_sequence.shape[0]
if self.framework == "pt":
generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
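A small sketch of the arithmetic behind `handle_long_generation="hole"` (numbers illustrative):

```python
model_max_length = 1024
new_tokens = 200
prompt_length = 1000

keep_length = model_max_length - new_tokens
if keep_length <= 0:
    raise ValueError("the number of desired tokens exceeds the model's max length")
# Only the last `keep_length` prompt tokens are kept before calling generate.
print(min(prompt_length, keep_length))  # 824
```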
diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py
index 4ea8d114150d..04a80b32dd58 100644
--- a/src/transformers/pipelines/token_classification.py
+++ b/src/transformers/pipelines/token_classification.py
@@ -133,11 +133,13 @@ def _sanitize_parameters(
if grouped_entities is not None:
warnings.warn(
- f'`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
+ "`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to"
+ f' `aggregation_strategy="{aggregation_strategy}"` instead.'
)
if ignore_subwords is not None:
warnings.warn(
- f'`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
+ "`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to"
+ f' `aggregation_strategy="{aggregation_strategy}"` instead.'
)
if aggregation_strategy is not None:
@@ -289,7 +291,7 @@ def gather_pre_entities(
AggregationStrategy.MAX,
}:
warnings.warn("Tokenizer does not support real words, using fallback heuristic", UserWarning)
- is_subword = sentence[start_ind - 1 : start_ind] != " " if start_ind > 0 else False
+ is_subword = start_ind > 0 and " " not in sentence[start_ind - 1 : start_ind + 1]
if int(input_ids[idx]) == self.tokenizer.unk_token_id:
word = word_ref
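The fallback heuristic above now treats a token as a subword only when there is no space on either side of its start offset; a tiny illustration:

```python
sentence = "Hugging Face"


def is_subword(start_ind):
    return start_ind > 0 and " " not in sentence[start_ind - 1 : start_ind + 1]


print(is_subword(3))  # True: the character at index 3 sits inside "Hugging"
print(is_subword(8))  # False: "Face" starts right after a space
```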
diff --git a/src/transformers/pipelines/visual_question_answering.py b/src/transformers/pipelines/visual_question_answering.py
new file mode 100644
index 000000000000..34a7a3b10d40
--- /dev/null
+++ b/src/transformers/pipelines/visual_question_answering.py
@@ -0,0 +1,115 @@
+from typing import Union
+
+from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging
+from .base import PIPELINE_INIT_ARGS, Pipeline
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from ..image_utils import load_image
+
+if is_torch_available():
+ from ..models.auto.modeling_auto import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
+
+logger = logging.get_logger(__name__)
+
+
+@add_end_docstrings(PIPELINE_INIT_ARGS)
+class VisualQuestionAnsweringPipeline(Pipeline):
+ """
+ Visual Question Answering pipeline using a `AutoModelForVisualQuestionAnswering`. This pipeline is currently only
+ available in PyTorch.
+
+ This visual question answering pipeline can currently be loaded from [`pipeline`] using the following task
+ identifiers: `"visual-question-answering", "vqa"`.
+
+ The models that this pipeline can use are models that have been fine-tuned on a visual question answering task. See
+ the up-to-date list of available models on
+ [huggingface.co/models](https://huggingface.co/models?filter=visual-question-answering).
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.check_model_type(MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING)
+
+ def _sanitize_parameters(self, top_k=None, padding=None, truncation=None, **kwargs):
+ preprocess_params, postprocess_params = {}, {}
+ if padding is not None:
+ preprocess_params["padding"] = padding
+ if truncation is not None:
+ preprocess_params["truncation"] = truncation
+ if top_k is not None:
+ postprocess_params["top_k"] = top_k
+ return preprocess_params, {}, postprocess_params
+
+ def __call__(self, image: Union["Image.Image", str], question: str = None, **kwargs):
+ r"""
+ Answers open-ended questions about images. The pipeline accepts several types of inputs which are detailed
+ below:
+
+ - `pipeline(image=image, question=question)`
+ - `pipeline({"image": image, "question": question})`
+ - `pipeline([{"image": image, "question": question}])`
+ - `pipeline([{"image": image, "question": question}, {"image": image, "question": question}])`
+
+ Args:
+ image (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+ The pipeline handles three types of images:
+
+ - A string containing a http link pointing to an image
+ - A string containing a local path to an image
+ - An image loaded in PIL directly
+
+ The pipeline accepts either a single image or a batch of images. If given a single image, it can be
+ broadcasted to multiple questions.
+ question (`str`, `List[str]`):
+ The question(s) asked. If given a single question, it can be broadcasted to multiple images.
+ top_k (`int`, *optional*, defaults to 5):
+ The number of top labels that will be returned by the pipeline. If the provided number is higher than
+ the number of labels available in the model configuration, it will default to the number of labels.
+ Return:
+ A dictionary or a list of dictionaries containing the result. The dictionaries contain the following keys:
+
+ - **answer** (`str`) -- The answer identified by the model.
+ - **score** (`float`) -- The score attributed by the model to that answer.
+ """
+ if isinstance(image, (Image.Image, str)) and isinstance(question, str):
+ inputs = {"image": image, "question": question}
+ else:
+ """
+ Supports the following format
+ - {"image": image, "question": question}
+ - [{"image": image, "question": question}]
+ - Generator and datasets
+ """
+ inputs = image
+ results = super().__call__(inputs, **kwargs)
+ return results
+
+ def preprocess(self, inputs, padding=False, truncation=False):
+ image = load_image(inputs["image"])
+ model_inputs = self.tokenizer(
+ inputs["question"], return_tensors=self.framework, padding=padding, truncation=truncation
+ )
+ image_features = self.feature_extractor(images=image, return_tensors=self.framework)
+ model_inputs.update(image_features)
+ return model_inputs
+
+ def _forward(self, model_inputs):
+ model_outputs = self.model(**model_inputs)
+ return model_outputs
+
+ def postprocess(self, model_outputs, top_k=5):
+ if top_k > self.model.config.num_labels:
+ top_k = self.model.config.num_labels
+
+ if self.framework == "pt":
+ probs = model_outputs.logits.sigmoid()[0]
+ scores, ids = probs.topk(top_k)
+ else:
+ raise ValueError(f"Unsupported framework: {self.framework}")
+
+ scores = scores.tolist()
+ ids = ids.tolist()
+ return [{"score": score, "answer": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py
index 9d5d5bd61b78..f98c87166ca0 100644
--- a/src/transformers/pipelines/zero_shot_classification.py
+++ b/src/transformers/pipelines/zero_shot_classification.py
@@ -86,7 +86,8 @@ def _parse_and_tokenize(
if self.tokenizer.pad_token is None:
# Override for tokenizers not supporting padding
logger.error(
- "Tokenizer was not supporting padding necessary for zero-shot, attempting to use `pad_token=eos_token`"
+ "Tokenizer was not supporting padding necessary for zero-shot, attempting to use "
+ " `pad_token=eos_token`"
)
self.tokenizer.pad_token = self.tokenizer.eos_token
try:
diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index 0a04813a3143..3915c3f8a5b7 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -109,24 +109,19 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
- Whether or not to push your processor to the Hugging Face model hub after saving it.
-
- <Tip warning={true}>
-
- Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
- which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
- folder. Pass along `temp_dir=True` to use a temporary directory instead.
-
- </Tip>
-
+ Whether or not to push your processor to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
+ os.makedirs(save_directory, exist_ok=True)
+
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
- repo = self._create_or_get_repo(save_directory, **kwargs)
-
- os.makedirs(save_directory, exist_ok=True)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id, token = self._create_repo(repo_id, **kwargs)
+ files_timestamps = self._get_files_timestamps(save_directory)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:
@@ -150,8 +145,9 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
del attribute.init_kwargs["auto_map"]
if push_to_hub:
- url = self._push_to_hub(repo, commit_message=commit_message)
- logger.info(f"Processor pushed to the hub in this commit: {url}")
+ self._upload_modified_files(
+ save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
+ )
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
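A hedged usage sketch of the updated push flow (processor class, local directory, and repo id are illustrative):

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
# `repo_id` selects the target repository; it defaults to the save_directory name.
processor.save_pretrained("clip-processor", push_to_hub=True, repo_id="my-user/clip-processor")
```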
diff --git a/src/transformers/py.typed b/src/transformers/py.typed
deleted file mode 100644
index 8b137891791f..000000000000
--- a/src/transformers/py.typed
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index ddfadfcbc07f..571a5d7d3c94 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -21,10 +21,16 @@
from .utils import logging
+ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
+
logger = logging.get_logger(__name__)
-is_torch_less_than_1_8 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.8.0")
-is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
+parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
+is_torch_greater_or_equal_than_1_6 = parsed_torch_version_base >= version.parse("1.6.0")
+is_torch_greater_than_1_6 = parsed_torch_version_base > version.parse("1.6.0")
+is_torch_less_than_1_8 = parsed_torch_version_base < version.parse("1.8.0")
+is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10")
+is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
def torch_int_div(tensor1, tensor2):
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 86d3673b7477..80f7bf9c863c 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import collections
import contextlib
import inspect
import logging
@@ -19,6 +20,7 @@
import re
import shlex
import shutil
+import subprocess
import sys
import tempfile
import unittest
@@ -26,7 +28,7 @@
from distutils.util import strtobool
from io import StringIO
from pathlib import Path
-from typing import Iterator, Union
+from typing import Iterator, List, Union
from unittest import mock
from transformers import logging as transformers_logging
@@ -40,12 +42,14 @@
is_wandb_available,
)
from .utils import (
+ is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
is_detectron2_available,
is_faiss_available,
is_flax_available,
is_ftfy_available,
+ is_ipex_available,
is_librosa_available,
is_onnx_available,
is_pandas_available,
@@ -60,15 +64,19 @@
is_soundfile_availble,
is_spacy_available,
is_tensorflow_probability_available,
+ is_tensorflow_text_available,
is_tf2onnx_available,
is_tf_available,
is_timm_available,
is_tokenizers_available,
is_torch_available,
- is_torch_bf16_available,
+ is_torch_bf16_cpu_available,
+ is_torch_bf16_gpu_available,
+ is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torchaudio_available,
+ is_torchdynamo_available,
is_vision_available,
)
@@ -80,8 +88,10 @@
# Used to test the hub
USER = "__DUMMY_TRANSFORMERS_USER__"
-PASS = "__DUMMY_TRANSFORMERS_PASS__"
-ENDPOINT_STAGING = "https://moon-staging.huggingface.co"
+ENDPOINT_STAGING = "https://hub-ci.huggingface.co"
+
+# Not critical, only usable on the sandboxed CI instance.
+TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL"
def parse_flag_from_env(key, default=False):
@@ -238,6 +248,13 @@ def require_git_lfs(test_case):
return unittest.skipUnless(_run_git_lfs_tests, "test of git lfs workflow")(test_case)
+def require_accelerate(test_case):
+ """
+ Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
+ """
+ return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
+
+
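A minimal sketch of how the new decorator is applied in a test module:

```python
import unittest

from transformers.testing_utils import require_accelerate


class AccelerateDependentTest(unittest.TestCase):
    @require_accelerate
    def test_something(self):
        # Skipped automatically when accelerate is not installed.
        self.assertTrue(True)
```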
def require_rjieba(test_case):
"""
Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
@@ -273,6 +290,21 @@ def require_torch(test_case):
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
+def require_intel_extension_for_pytorch(test_case):
+ """
+ Decorator marking a test that requires Intel Extension for PyTorch.
+
+ These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match the current PyTorch
+ version.
+
+ """
+ return unittest.skipUnless(
+ is_ipex_available(),
+ "test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see"
+ " https://github.com/intel/intel-extension-for-pytorch",
+ )(test_case)
+
+
def require_torch_scatter(test_case):
"""
Decorator marking a test that requires PyTorch scatter.
@@ -337,6 +369,14 @@ def require_tokenizers(test_case):
return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case)
+def require_tensorflow_text(test_case):
+ """
+ Decorator marking a test that requires tensorflow_text. These tests are skipped when tensorflow_text isn't
+ installed.
+ """
+ return unittest.skipUnless(is_tensorflow_text_available(), "test requires tensorflow_text")(test_case)
+
+
def require_pandas(test_case):
"""
Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
@@ -434,7 +474,7 @@ def require_torch_tpu(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
"""
- return unittest.skipUnless(is_torch_tpu_available(), "test requires PyTorch TPU")(test_case)
+ return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)
if is_torch_available():
@@ -456,15 +496,34 @@ def require_torch_tpu(test_case):
jax_device = None
+def require_torchdynamo(test_case):
+ """Decorator marking a test that requires TorchDynamo"""
+ return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
+
+
+def require_torch_tensorrt_fx(test_case):
+ """Decorator marking a test that requires Torch-TensorRT FX"""
+ return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
+
+
def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
-def require_torch_bf16(test_case):
- """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
+def require_torch_bf16_gpu(test_case):
+ """Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
+ return unittest.skipUnless(
+ is_torch_bf16_gpu_available(),
+ "test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0",
+ )(test_case)
+
+
+def require_torch_bf16_cpu(test_case):
+ """Decorator marking a test that requires torch>=1.10, using CPU."""
return unittest.skipUnless(
- is_torch_bf16_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10"
+ is_torch_bf16_cpu_available(),
+ "test requires torch>=1.10, using CPU",
)(test_case)
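The new decorators follow the same `unittest.skipUnless` pattern as the existing `require_*` helpers; a hypothetical test using them might look like this (sketch, not part of the diff):

    import unittest

    from transformers.testing_utils import require_torch_gpu, require_torchdynamo


    class DynamoTrainerTest(unittest.TestCase):
        @require_torchdynamo
        @require_torch_gpu
        def test_runs_under_dynamo(self):
            # automatically skipped unless TorchDynamo and a CUDA device are available
            pass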
@@ -1327,9 +1386,13 @@ def summary_failures_short(tr):
tr.summary_warnings() # final warnings
tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary())
- with open(report_files["passes"], "w") as f:
- tr._tw = create_terminal_writer(config, f)
- tr.summary_passes()
+
+ # Skip the `passes` report, as it starts to take more than 5 minutes, and sometimes it times out on CircleCI if it
+ # takes > 10 minutes (as this part doesn't generate any output on the terminal).
+ # (also, it seems there is no useful information in this report, and we rarely need to read it)
+ # with open(report_files["passes"], "w") as f:
+ # tr._tw = create_terminal_writer(config, f)
+ # tr.summary_passes()
with open(report_files["summary_short"], "w") as f:
tr._tw = create_terminal_writer(config, f)
@@ -1480,3 +1543,48 @@ def nested_simplify(obj, decimals=3):
return nested_simplify(obj.item(), decimals)
else:
raise Exception(f"Not supported: {type(obj)}")
+
+
+def check_json_file_has_correct_format(file_path):
+ with open(file_path, "r") as f:
+ lines = f.readlines()
+ if len(lines) == 1:
+ # length can only be 1 if dict is empty
+ assert lines[0] == "{}"
+ else:
+ # otherwise make sure json has correct format (at least 3 lines)
+ assert len(lines) >= 3
+ # each key on its own line, indent should be 2, min length is 3
+ assert lines[0].strip() == "{"
+ for line in lines[1:-1]:
+ left_indent = len(line) - len(line.lstrip())
+ assert left_indent == 2
+ assert lines[-1].strip() == "}"
+
+
+def to_2tuple(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return (x, x)
+
+
+# These utils relate to ensuring the right error message is received when running scripts
+class SubprocessCallException(Exception):
+ pass
+
+
+def run_command(command: List[str], return_stdout=False):
+ """
+ Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
+ the error message if an error occurred while running `command`.
+ """
+ try:
+ output = subprocess.check_output(command, stderr=subprocess.STDOUT)
+ if return_stdout:
+ if hasattr(output, "decode"):
+ output = output.decode("utf-8")
+ return output
+ except subprocess.CalledProcessError as e:
+ raise SubprocessCallException(
+ f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
+ ) from e
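A short usage sketch for the `run_command` helper added above (hypothetical command, not part of the diff):

    from transformers.testing_utils import SubprocessCallException, run_command

    try:
        stdout = run_command(["python", "-c", "print('hello')"], return_stdout=True)
    except SubprocessCallException as e:
        # the raised message embeds the failing command and its captured output
        print(e)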
diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py
index 694b55cedd3e..6d33266c03f4 100644
--- a/src/transformers/tokenization_utils.py
+++ b/src/transformers/tokenization_utils.py
@@ -250,7 +250,8 @@ def cut_text(self, text, offsets):
for end in offsets:
if start > end:
logger.error(
- "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
+ "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it"
+ " anyway."
)
continue
elif start == end:
@@ -627,11 +628,13 @@ def get_input_ids(text):
else:
if is_split_into_words:
raise ValueError(
- f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_split_into_words=True`."
+ f"Input {text} is not valid. Should be a string or a list/tuple of strings when"
+ " `is_split_into_words=True`."
)
else:
raise ValueError(
- f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
+ f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of"
+ " integers."
)
if return_offsets_mapping:
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 43d37e67cc50..f85dc73cb659 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -35,21 +35,16 @@
from . import __version__
from .dynamic_module_utils import custom_object_save
from .utils import (
- EntryNotFoundError,
ExplicitEnum,
PaddingStrategy,
PushToHubMixin,
- RepositoryNotFoundError,
- RevisionNotFoundError,
TensorType,
add_end_docstrings,
- cached_path,
+ cached_file,
copy_func,
get_file_from_repo,
- hf_bucket_url,
is_flax_available,
is_offline_mode,
- is_remote_url,
is_tf_available,
is_tokenizers_available,
is_torch_available,
@@ -291,7 +286,10 @@ def tokens(self, batch_index: int = 0) -> List[str]:
`List[str]`: The list of tokens at that index.
"""
if not self._encodings:
- raise ValueError("tokens() is not available when using Python-based tokenizers")
+ raise ValueError(
+ "tokens() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`"
+ " class)."
+ )
return self._encodings[batch_index].tokens
def sequence_ids(self, batch_index: int = 0) -> List[Optional[int]]:
@@ -312,7 +310,10 @@ def sequence_ids(self, batch_index: int = 0) -> List[Optional[int]]:
sequence.
"""
if not self._encodings:
- raise ValueError("sequence_ids() is not available when using Python-based tokenizers")
+ raise ValueError(
+ "sequence_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`"
+ " class)."
+ )
return self._encodings[batch_index].sequence_ids
def words(self, batch_index: int = 0) -> List[Optional[int]]:
@@ -328,7 +329,10 @@ def words(self, batch_index: int = 0) -> List[Optional[int]]:
(several tokens will be mapped to the same word index if they are parts of that word).
"""
if not self._encodings:
- raise ValueError("words() is not available when using Python-based tokenizers")
+ raise ValueError(
+ "words() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`"
+ " class)."
+ )
warnings.warn(
"`BatchEncoding.words()` property is deprecated and should be replaced with the identical, "
"but more self-explanatory `BatchEncoding.word_ids()` property.",
@@ -349,7 +353,10 @@ def word_ids(self, batch_index: int = 0) -> List[Optional[int]]:
(several tokens will be mapped to the same word index if they are parts of that word).
"""
if not self._encodings:
- raise ValueError("word_ids() is not available when using Python-based tokenizers")
+ raise ValueError(
+ "word_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`"
+ " class)."
+ )
return self._encodings[batch_index].word_ids
def token_to_sequence(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:
@@ -721,8 +728,10 @@ def convert_to_tensors(
"Please see if a fast version of this tokenizer is available to have this feature available."
)
raise ValueError(
- "Unable to create tensor, you should probably activate truncation and/or padding "
- "with 'padding=True' 'truncation=True' to have batched tensors with the same length."
+ "Unable to create tensor, you should probably activate truncation and/or padding with"
+ " 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your"
+ f" features (`{key}` in this case) have excessive nesting (inputs type `list` where type `int` is"
+ " expected)."
)
return self
@@ -956,8 +965,9 @@ def bos_token(self) -> str:
"""
`str`: Beginning of sentence token. Log an error if used while not having been set.
"""
- if self._bos_token is None and self.verbose:
- logger.error("Using bos_token, but it is not set yet.")
+ if self._bos_token is None:
+ if self.verbose:
+ logger.error("Using bos_token, but it is not set yet.")
return None
return str(self._bos_token)
@@ -966,8 +976,9 @@ def eos_token(self) -> str:
"""
`str`: End of sentence token. Log an error if used while not having been set.
"""
- if self._eos_token is None and self.verbose:
- logger.error("Using eos_token, but it is not set yet.")
+ if self._eos_token is None:
+ if self.verbose:
+ logger.error("Using eos_token, but it is not set yet.")
return None
return str(self._eos_token)
@@ -976,8 +987,9 @@ def unk_token(self) -> str:
"""
`str`: Unknown token. Log an error if used while not having been set.
"""
- if self._unk_token is None and self.verbose:
- logger.error("Using unk_token, but it is not set yet.")
+ if self._unk_token is None:
+ if self.verbose:
+ logger.error("Using unk_token, but it is not set yet.")
return None
return str(self._unk_token)
@@ -987,8 +999,9 @@ def sep_token(self) -> str:
`str`: Separation token, to separate context and query in an input sequence. Log an error if used while not
having been set.
"""
- if self._sep_token is None and self.verbose:
- logger.error("Using sep_token, but it is not set yet.")
+ if self._sep_token is None:
+ if self.verbose:
+ logger.error("Using sep_token, but it is not set yet.")
return None
return str(self._sep_token)
@@ -997,8 +1010,9 @@ def pad_token(self) -> str:
"""
`str`: Padding token. Log an error if used while not having been set.
"""
- if self._pad_token is None and self.verbose:
- logger.error("Using pad_token, but it is not set yet.")
+ if self._pad_token is None:
+ if self.verbose:
+ logger.error("Using pad_token, but it is not set yet.")
return None
return str(self._pad_token)
@@ -1008,8 +1022,9 @@ def cls_token(self) -> str:
`str`: Classification token, to extract a summary of an input sequence leveraging self-attention along the full
depth of the model. Log an error if used while not having been set.
"""
- if self._cls_token is None and self.verbose:
- logger.error("Using cls_token, but it is not set yet.")
+ if self._cls_token is None:
+ if self.verbose:
+ logger.error("Using cls_token, but it is not set yet.")
return None
return str(self._cls_token)
@@ -1019,8 +1034,9 @@ def mask_token(self) -> str:
`str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not
having been set.
"""
- if self._mask_token is None and self.verbose:
- logger.error("Using mask_token, but it is not set yet.")
+ if self._mask_token is None:
+ if self.verbose:
+ logger.error("Using mask_token, but it is not set yet.")
return None
return str(self._mask_token)
@@ -1030,8 +1046,9 @@ def additional_special_tokens(self) -> List[str]:
`List[str]`: All the additional special tokens you may want to use. Log an error if used while not having been
set.
"""
- if self._additional_special_tokens is None and self.verbose:
- logger.error("Using additional_special_tokens, but it is not set yet.")
+ if self._additional_special_tokens is None:
+ if self.verbose:
+ logger.error("Using additional_special_tokens, but it is not set yet.")
return None
return [str(tok) for tok in self._additional_special_tokens]
@@ -1479,7 +1496,7 @@ def __init__(self, **kwargs):
self.deprecation_warnings = (
{}
) # Use to store when we have already noticed a deprecation warning (avoid overlogging).
-
+ self._in_target_context_manager = False
super().__init__(**kwargs)
@property
@@ -1502,12 +1519,12 @@ def max_len_single_sentence(self, value) -> int:
if value == self.model_max_length - self.num_special_tokens_to_add(pair=False) and self.verbose:
if not self.deprecation_warnings.get("max_len_single_sentence", False):
logger.warning(
- "Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up."
+ "Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up."
)
self.deprecation_warnings["max_len_single_sentence"] = True
else:
raise ValueError(
- "Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up."
+ "Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up."
)
@max_len_sentences_pair.setter
@@ -1516,13 +1533,11 @@ def max_len_sentences_pair(self, value) -> int:
if value == self.model_max_length - self.num_special_tokens_to_add(pair=True) and self.verbose:
if not self.deprecation_warnings.get("max_len_sentences_pair", False):
logger.warning(
- "Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up."
+ "Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up."
)
self.deprecation_warnings["max_len_sentences_pair"] = True
else:
- raise ValueError(
- "Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up."
- )
+ raise ValueError("Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up.")
def _set_processor_class(self, processor_class: str):
"""Sets processor class as an attribute."""
@@ -1530,9 +1545,10 @@ def _set_processor_class(self, processor_class: str):
def __repr__(self) -> str:
return (
- f"{'PreTrainedTokenizerFast' if self.is_fast else 'PreTrainedTokenizer'}(name_or_path='{self.name_or_path}', "
- f"vocab_size={self.vocab_size}, model_max_len={self.model_max_length}, is_fast={self.is_fast}, "
- f"padding_side='{self.padding_side}', truncation_side='{self.truncation_side}', special_tokens={self.special_tokens_map_extended})"
+ f"{'PreTrainedTokenizerFast' if self.is_fast else 'PreTrainedTokenizer'}(name_or_path='{self.name_or_path}',"
+ f" vocab_size={self.vocab_size}, model_max_len={self.model_max_length}, is_fast={self.is_fast},"
+ f" padding_side='{self.padding_side}', truncation_side='{self.truncation_side}',"
+ f" special_tokens={self.special_tokens_map_extended})"
)
def get_vocab(self) -> Dict[str, int]:
@@ -1580,7 +1596,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
- when running `transformers-cli login` (stored in `~/.huggingface`).
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only rely on local files and not to attempt to download any files.
revision (`str`, *optional*, defaults to `"main"`):
@@ -1648,7 +1664,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
vocab_files = {}
init_configuration = {}
- if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if os.path.isfile(pretrained_model_name_or_path):
if len(cls.vocab_files_names) > 1:
raise ValueError(
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
@@ -1668,9 +1685,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
}
- vocab_files_target = {**cls.vocab_files_names, **additional_files_names}
+ vocab_files = {**cls.vocab_files_names, **additional_files_names}
- if "tokenizer_file" in vocab_files_target:
+ if "tokenizer_file" in vocab_files:
# Try to get the tokenizer config to see if there are versioned tokenizer files.
fast_tokenizer_file = FULL_TOKENIZER_FILE
resolved_config_file = get_file_from_repo(
@@ -1683,34 +1700,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
+ subfolder=subfolder,
)
if resolved_config_file is not None:
with open(resolved_config_file, encoding="utf-8") as reader:
tokenizer_config = json.load(reader)
if "fast_tokenizer_files" in tokenizer_config:
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
- vocab_files_target["tokenizer_file"] = fast_tokenizer_file
-
- # Look for the tokenizer files
- for file_id, file_name in vocab_files_target.items():
- if os.path.isdir(pretrained_model_name_or_path):
- if subfolder is not None:
- full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name)
- else:
- full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
- if not os.path.exists(full_file_name):
- logger.info(f"Didn't find file {full_file_name}. We won't load it.")
- full_file_name = None
- else:
- full_file_name = hf_bucket_url(
- pretrained_model_name_or_path,
- filename=file_name,
- subfolder=subfolder,
- revision=revision,
- mirror=None,
- )
-
- vocab_files[file_id] = full_file_name
+ vocab_files["tokenizer_file"] = fast_tokenizer_file
# Get files from url, cache, or disk depending on the case
resolved_vocab_files = {}
@@ -1719,44 +1716,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
if file_path is None:
resolved_vocab_files[file_id] = None
else:
- try:
- resolved_vocab_files[file_id] = cached_path(
- file_path,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- user_agent=user_agent,
- )
-
- except FileNotFoundError as error:
- if local_files_only:
- unresolved_files.append(file_id)
- else:
- raise error
-
- except RepositoryNotFoundError:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
- "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
- "pass a token having permission to this repo with `use_auth_token` or log in with "
- "`huggingface-cli login` and pass `use_auth_token=True`."
- )
- except RevisionNotFoundError:
- raise EnvironmentError(
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
- "for this model name. Check the model page at "
- f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
- )
- except EntryNotFoundError:
- logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.")
- resolved_vocab_files[file_id] = None
-
- except ValueError:
- logger.debug(f"Connection problem to access {file_path} and it wasn't found in the cache.")
- resolved_vocab_files[file_id] = None
+ resolved_vocab_files[file_id] = cached_file(
+ pretrained_model_name_or_path,
+ file_path,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ _raise_exceptions_for_connection_errors=False,
+ )
if len(unresolved_files) > 0:
logger.info(
@@ -1776,7 +1750,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
if file_id not in resolved_vocab_files:
continue
- if file_path == resolved_vocab_files[file_id]:
+ if is_local:
logger.info(f"loading file {file_path}")
else:
logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
@@ -1788,6 +1762,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
*init_inputs,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
+ local_files_only=local_files_only,
**kwargs,
)
@@ -1800,6 +1775,7 @@ def _from_pretrained(
*init_inputs,
use_auth_token=None,
cache_dir=None,
+ local_files_only=False,
**kwargs
):
# We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
@@ -1812,6 +1788,9 @@ def _from_pretrained(
pretrained_model_name_or_path,
copy.deepcopy(init_configuration),
*init_inputs,
+ use_auth_token=use_auth_token,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
**(copy.deepcopy(kwargs)),
)
else:
@@ -1843,6 +1822,7 @@ def _from_pretrained(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
+ local_files_only=local_files_only,
)
config_tokenizer_class = config.tokenizer_class
except (OSError, ValueError, KeyError):
@@ -1873,10 +1853,10 @@ def _from_pretrained(
if config_tokenizer_class is not None:
if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
logger.warning(
- "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. "
- "It may result in unexpected tokenization. \n"
- f"The tokenizer class you load from this checkpoint is '{config_tokenizer_class}'. \n"
- f"The class this function is called from is '{cls.__name__}'."
+ "The tokenizer class you load from this checkpoint is not the same type as the class this"
+ " function is called from. It may result in unexpected tokenization. \nThe tokenizer class you"
+ f" load from this checkpoint is '{config_tokenizer_class}'. \nThe class this function is called"
+ f" from is '{cls.__name__}'."
)
# Update with newly provided kwargs
@@ -1965,30 +1945,45 @@ def convert_added_tokens(obj: Union[AddedToken, Any]):
# Sort added tokens by index
added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1]))
+ # Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
+ # individual tokens would repeatedly rebuild a trie, which can be slow.
+ is_last_special = None
+ tokens = []
+
for token, index in added_tok_encoder_sorted:
- if has_tokenizer_file and index != len(tokenizer) and tokenizer.convert_tokens_to_ids(token) != index:
+ current_index = len(tokenizer) + len(tokens)
+ if has_tokenizer_file and index != current_index and tokenizer.convert_tokens_to_ids(token) != index:
# Tokenizer fast: added token needs to either be in the vocabulary with the proper index or the
# index is the current length of the tokenizer (not in vocabulary)
raise ValueError(
f"Wrong index found for {token}: should be {tokenizer.convert_tokens_to_ids(token)} but found "
f"{index}."
)
- elif not has_tokenizer_file and index != len(tokenizer):
+ elif not has_tokenizer_file and index != current_index:
# Tokenizer slow: added token cannot already be in the vocabulary so its index needs to be the
# current length of the tokenizer.
raise ValueError(
f"Non-consecutive added token '{token}' found. "
- f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary."
+ f"Should have index {current_index} but has index {index} in saved vocabulary."
)
- # Safe to call on a tokenizer fast even if token already there.
- tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens))
+ is_special = bool(token in special_tokens)
+ if is_last_special is None or is_last_special == is_special:
+ tokens.append(token)
+ else:
+ tokenizer.add_tokens(tokens, special_tokens=is_last_special)
+ tokens = [token]
+ is_last_special = is_special
+
+ if tokens:
+ tokenizer.add_tokens(tokens, special_tokens=is_last_special)
# Check all our special tokens are registered as "no split" token (we don't cut them) and are in the vocab
added_tokens = tokenizer.sanitize_special_tokens()
if added_tokens:
logger.warning_advice(
- "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained."
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are"
+ " fine-tuned or trained."
)
return tokenizer
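As the comment above notes, per-token `add_tokens()` calls rebuild a trie each time, which is what the batching avoids; roughly (hypothetical tokens, any tokenizer instance assumed, not part of the diff):

    # slow path: each call rebuilds the trie
    for tok in ["<ent_a>", "<ent_b>", "<ent_c>"]:
        tokenizer.add_tokens(tok, special_tokens=False)

    # batched path (what the loop above accumulates towards): one rebuild
    tokenizer.add_tokens(["<ent_a>", "<ent_b>", "<ent_c>"], special_tokens=False)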
@@ -2035,15 +2030,11 @@ def save_pretrained(
filename_prefix: (`str`, *optional*):
A prefix to add to the names of the files saved by the tokenizer.
push_to_hub (`bool`, *optional*, defaults to `False`):
- Whether or not to push your model to the Hugging Face model hub after saving it.
-
-
-
- Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
- which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
- folder. Pass along `temp_dir=True` to use a temporary directory instead.
-
-
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs:
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
Returns:
A tuple of `str`: The files saved.
@@ -2052,11 +2043,13 @@ def save_pretrained(
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
+ os.makedirs(save_directory, exist_ok=True)
+
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
- repo = self._create_or_get_repo(save_directory, **kwargs)
-
- os.makedirs(save_directory, exist_ok=True)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id, token = self._create_repo(repo_id, **kwargs)
+ files_timestamps = self._get_files_timestamps(save_directory)
special_tokens_map_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + SPECIAL_TOKENS_MAP_FILE
@@ -2104,13 +2097,15 @@ def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
custom_object_save(self, save_directory, config=tokenizer_config)
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(tokenizer_config, ensure_ascii=False))
+ out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+ f.write(out_str)
logger.info(f"tokenizer config file saved in {tokenizer_config_file}")
# Sanitize AddedTokens in special_tokens_map
write_dict = convert_added_tokens(self.special_tokens_map_extended, add_type_field=False)
with open(special_tokens_map_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(write_dict, ensure_ascii=False))
+ out_str = json.dumps(write_dict, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+ f.write(out_str)
logger.info(f"Special tokens file saved in {special_tokens_map_file}")
file_names = (tokenizer_config_file, special_tokens_map_file)
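The serialization change above turns the single-line JSON dump into a pretty-printed, key-sorted file with a trailing newline, which is also what the `check_json_file_has_correct_format` helper added in `testing_utils.py` checks. A standalone illustration (not part of the diff):

    import json

    config = {"model_max_length": 512, "do_lower_case": True}
    out_str = json.dumps(config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
    # out_str is now:
    # {
    #   "do_lower_case": true,
    #   "model_max_length": 512
    # }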
@@ -2123,8 +2118,9 @@ def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
)
if push_to_hub:
- url = self._push_to_hub(repo, commit_message=commit_message)
- logger.info(f"Tokenizer pushed to the hub in this commit: {url}")
+ self._upload_modified_files(
+ save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
+ )
return save_files
@@ -2154,7 +2150,7 @@ def _save_pretrained(
added_vocab = self.get_added_vocab()
if added_vocab:
with open(added_tokens_file, "w", encoding="utf-8") as f:
- out_str = json.dumps(added_vocab, ensure_ascii=False)
+ out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
f.write(out_str)
logger.info(f"added tokens file saved in {added_tokens_file}")
@@ -2270,11 +2266,11 @@ def _get_padding_truncation_strategies(
if verbose:
if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
logger.warning(
- "Truncation was not explicitly activated but `max_length` is provided a specific value, "
- "please use `truncation=True` to explicitly truncate examples to max length. "
- "Defaulting to 'longest_first' truncation strategy. "
- "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
- "more precisely by providing a specific strategy to `truncation`."
+ "Truncation was not explicitly activated but `max_length` is provided a specific value, please"
+ " use `truncation=True` to explicitly truncate examples to max length. Defaulting to"
+ " 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the"
+ " tokenizer you can select this strategy more precisely by providing a specific strategy to"
+ " `truncation`."
)
self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
truncation = "longest_first"
@@ -2316,14 +2312,14 @@ def _get_padding_truncation_strategies(
if truncation is False and old_truncation_strategy != "do_not_truncate":
if verbose:
warnings.warn(
- "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
- "use `truncation=True` to truncate examples to a max length. You can give a specific "
- "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
- "maximal input size of the model (e.g. 512 for Bert). "
- " If you have pairs of inputs, you can give a specific truncation strategy selected among "
- "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
- "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
- "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
+ "The `truncation_strategy` argument is deprecated and will be removed in a future version, use"
+ " `truncation=True` to truncate examples to a max length. You can give a specific length with"
+ " `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the maximal input"
+ " size of the model (e.g. 512 for Bert). If you have pairs of inputs, you can give a specific"
+ " truncation strategy selected among `truncation='only_first'` (will only truncate the first"
+ " sentence in the pairs) `truncation='only_second'` (will only truncate the second sentence in the"
+ " pairs) or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence"
+ " in the pairs).",
FutureWarning,
)
truncation_strategy = TruncationStrategy(old_truncation_strategy)
@@ -2346,8 +2342,8 @@ def _get_padding_truncation_strategies(
if verbose:
if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
logger.warning(
- "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
- "Default to no padding."
+ "Asking to pad to max_length but no maximum length is provided and the model has no"
+ " predefined maximum length. Default to no padding."
)
self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
padding_strategy = PaddingStrategy.DO_NOT_PAD
@@ -2359,8 +2355,8 @@ def _get_padding_truncation_strategies(
if verbose:
if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
logger.warning(
- "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
- "Default to no truncation."
+ "Asking to truncate to max_length but no maximum length is provided and the model has"
+ " no predefined maximum length. Default to no truncation."
)
self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
@@ -2384,7 +2380,7 @@ def _get_padding_truncation_strategies(
and (max_length % pad_to_multiple_of != 0)
):
raise ValueError(
- f"Truncation and padding are both activated but "
+ "Truncation and padding are both activated but "
f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
)
@@ -2393,8 +2389,12 @@ def _get_padding_truncation_strategies(
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
def __call__(
self,
- text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+ text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+ text_pair_target: Optional[
+ Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
+ ] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
@@ -2417,15 +2417,85 @@ def __call__(
sequences.
Args:
- text (`str`, `List[str]`, `List[List[str]]`):
+ text (`str`, `List[str]`, `List[List[str]]`, *optional*):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
- text_pair (`str`, `List[str]`, `List[List[str]]`):
+ text_pair (`str`, `List[str]`, `List[List[str]]`, *optional*):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
- """
+ text_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
+ The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
+ list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
+ you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ text_pair_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
+ The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
+ list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
+ you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ """
+ # To avoid duplicating the shared keyword arguments in the two encoding calls below
+ all_kwargs = dict(
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ is_split_into_words=is_split_into_words,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ )
+ all_kwargs.update(kwargs)
+ if text is None and text_target is None:
+ raise ValueError("You need to specify either `text` or `text_target`.")
+ if text is not None:
+ # The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the
+ # input mode in this case.
+ if not self._in_target_context_manager:
+ self._switch_to_input_mode()
+ encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
+ if text_target is not None:
+ self._switch_to_target_mode()
+ target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **all_kwargs)
+ # Leave back tokenizer in input mode
+ self._switch_to_input_mode()
+
+ if text_target is None:
+ return encodings
+ elif text is None:
+ return target_encodings
+ else:
+ encodings["labels"] = target_encodings["input_ids"]
+ return encodings
+
+ def _call_one(
+ self,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_split_into_words: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
# Input type checking for clearer error
def _is_valid_text_input(t):
if isinstance(t, str):
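To illustrate the dispatch implemented in `__call__` above: passing `text_target` triggers a second encoding pass in target mode and stores its `input_ids` under `labels`. A minimal sketch (assumes a seq2seq tokenizer such as the T5 tokenizer; not part of the diff):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    batch = tokenizer(
        ["translate English to German: Hello world"],
        text_target=["Hallo Welt"],
        return_tensors="pt",
    )
    # input_ids / attention_mask come from the source pass,
    # labels holds the input_ids produced for the target text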
@@ -2467,11 +2537,13 @@ def _is_valid_text_input(t):
if is_batched:
if isinstance(text_pair, str):
raise TypeError(
- "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as `text`."
+ "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as"
+ " `text`."
)
if text_pair is not None and len(text) != len(text_pair):
raise ValueError(
- f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}."
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
)
batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
return self.batch_encode_plus(
@@ -2826,7 +2898,7 @@ def pad(
else:
raise ValueError(
f"type of {first_element} unknown: {type(first_element)}. "
- f"Should be one of a python, numpy, pytorch or tensorflow object."
+ "Should be one of a python, numpy, pytorch or tensorflow object."
)
for key, value in encoded_inputs.items():
@@ -3123,16 +3195,17 @@ def truncate_sequences(
)
if truncation_strategy == TruncationStrategy.ONLY_FIRST:
error_msg = (
- error_msg + "Please select another truncation strategy than "
+ error_msg
+ + "Please select another truncation strategy than "
f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
)
logger.error(error_msg)
elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
logger.warning(
- f"Be aware, overflowing tokens are not returned for the setting you have chosen,"
+ "Be aware, overflowing tokens are not returned for the setting you have chosen,"
f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
- f"truncation strategy. So the returned list will always be empty even if some "
- f"tokens have been removed."
+ "truncation strategy. So the returned list will always be empty even if some "
+ "tokens have been removed."
)
for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids):
@@ -3165,7 +3238,7 @@ def truncate_sequences(
f"We need to remove {num_tokens_to_remove} to truncate the input "
f"but the second sequence has a length {len(pair_ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
- f"for instance 'longest_first' or 'only_first'."
+ "for instance 'longest_first' or 'only_first'."
)
return (ids, pair_ids, overflowing_tokens)
@@ -3415,13 +3488,34 @@ def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Opt
)
self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
+ def _switch_to_input_mode(self):
+ """
+ Private method to put the tokenizer in input mode (when it has different modes for input/outputs)
+ """
+ pass
+
+ def _switch_to_target_mode(self):
+ """
+ Private method to put the tokenizer in target mode (when it has different modes for input/outputs)
+ """
+ pass
+
@contextmanager
def as_target_tokenizer(self):
"""
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
sequence-to-sequence models that need a slightly different processing for the labels.
"""
+ warnings.warn(
+ "`as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your "
+ "labels by using the argument `text_target` of the regular `__call__` method (either in the same call as "
+ "your input texts if you use the same keyword arguments, or in a separate call."
+ )
+ self._switch_to_target_mode()
+ self._in_target_context_manager = True
yield
+ self._in_target_context_manager = False
+ self._switch_to_input_mode()
@classmethod
def register_for_auto_class(cls, auto_class="AutoTokenizer"):
@@ -3522,14 +3616,17 @@ def prepare_seq2seq_batch(
# docstyle-ignore
formatted_warning = """
`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of HuggingFace Transformers. Use the regular
-`__call__` method to prepare your inputs and the tokenizer under the `as_target_tokenizer` context manager to prepare
-your targets.
+`__call__` method to prepare your inputs and targets.
Here is a short example:
+model_inputs = tokenizer(src_texts, text_target=tgt_texts, ...)
+
+If you need to use different keyword arguments for the source and target texts, you should do two calls like
+this:
+
model_inputs = tokenizer(src_texts, ...)
-with tokenizer.as_target_tokenizer():
- labels = tokenizer(tgt_texts, ...)
+labels = tokenizer(text_target=tgt_texts, ...)
model_inputs["labels"] = labels["input_ids"]
See the documentation of your specific tokenizer for more details on the specific arguments to the tokenizer of choice.
diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py
index 4f85a842dd3d..a061685b0bf1 100644
--- a/src/transformers/tokenization_utils_fast.py
+++ b/src/transformers/tokenization_utils_fast.py
@@ -16,11 +16,13 @@
Tokenization classes for fast tokenizers (provided by HuggingFace's tokenizers library). For slow (python) tokenizers
see tokenization_utils.py
"""
+import copy
import json
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union
+import tokenizers.pre_tokenizers as pre_tokenizers_fast
from tokenizers import Encoding as EncodingFast
from tokenizers import Tokenizer as TokenizerFast
from tokenizers.decoders import Decoder as DecoderFast
@@ -103,7 +105,7 @@ def __init__(self, *args, **kwargs):
)
if tokenizer_object is not None:
- fast_tokenizer = tokenizer_object
+ fast_tokenizer = copy.deepcopy(tokenizer_object)
elif fast_tokenizer_file is not None and not from_slow:
# We have a serialization from tokenizers which let us directly build the backend
fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
@@ -567,8 +569,8 @@ def _save_pretrained(
if self.slow_tokenizer_class is None and legacy_format is True:
raise ValueError(
- "Your tokenizer does not have a legacy version defined and therefore cannot register this version. You "
- "might consider leaving the legacy_format at `None` or setting it to `False`."
+ "Your tokenizer does not have a legacy version defined and therefore cannot register this version. You"
+ " might consider leaving the legacy_format at `None` or setting it to `False`."
)
save_slow = (
@@ -585,7 +587,7 @@ def _save_pretrained(
added_vocab = self.get_added_vocab()
if added_vocab:
with open(added_tokens_file, "w", encoding="utf-8") as f:
- out_str = json.dumps(added_vocab, ensure_ascii=False)
+ out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
f.write(out_str)
vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)
@@ -699,6 +701,8 @@ def train_new_from_iterator(
kwargs["end_of_word_suffix"] = tokenizer_json["model"]["end_of_word_suffix"]
if tokenizer_json["model"]["type"] == "Unigram" and unk_token is not None:
kwargs["unk_token"] = unk_token
+ if tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel":
+ kwargs["initial_alphabet"] = pre_tokenizers_fast.ByteLevel.alphabet()
trainer_class = MODEL_TO_TRAINER_MAPPING[tokenizer_json["model"]["type"]]
trainer = trainer_class(vocab_size=vocab_size, special_tokens=special_tokens, **kwargs)
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 1a8cac0722e8..66879d659654 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -17,6 +17,8 @@
"""
import contextlib
+import functools
+import glob
import inspect
import math
import os
@@ -63,11 +65,18 @@
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
-from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enabled
+from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
+from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
from .optimization import Adafactor, get_scheduler
+from .pytorch_utils import (
+ ALL_LAYERNORM_LAYERS,
+ is_torch_greater_or_equal_than_1_6,
+ is_torch_greater_or_equal_than_1_10,
+ is_torch_less_than_1_11,
+)
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
CallbackHandler,
@@ -90,6 +99,7 @@
distributed_broadcast_scalars,
distributed_concat,
find_batch_size,
+ get_module_class_from_name,
get_parameter_names,
nested_concat,
nested_detach,
@@ -103,19 +113,24 @@
BestRun,
EvalLoopOutput,
EvalPrediction,
+ FSDPOption,
HPSearchBackend,
HubStrategy,
IntervalStrategy,
PredictionOutput,
+ RemoveColumnsCollator,
ShardedDDPOption,
TrainerMemoryTracker,
TrainOutput,
default_compute_objective,
default_hp_space,
denumpify_detensorize,
+ enable_full_determinism,
+ find_executable_batch_size,
get_last_checkpoint,
has_length,
number_of_arguments,
+ seed_worker,
set_seed,
speed_metrics,
)
@@ -129,15 +144,20 @@
is_apex_available,
is_datasets_available,
is_in_notebook,
+ is_ipex_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
+ is_torch_tensorrt_fx_available,
is_torch_tpu_available,
+ is_torchdynamo_available,
logging,
)
+from .utils.generic import ContextManagers
_is_torch_generator_available = False
-_is_native_amp_available = False
+_is_native_cuda_amp_available = False
+_is_native_cpu_amp_available = False
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
@@ -150,15 +170,17 @@
if is_apex_available():
from apex import amp
-if version.parse(torch.__version__) >= version.parse("1.6"):
+if is_torch_greater_or_equal_than_1_6:
_is_torch_generator_available = True
- _is_native_amp_available = True
- from torch.cuda.amp import autocast
+ _is_native_cuda_amp_available = True
+
+if is_torch_greater_or_equal_than_1_10:
+ _is_native_cpu_amp_available = True
if is_datasets_available():
import datasets
-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
@@ -175,8 +197,13 @@
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
+ from smdistributed.modelparallel import __version__ as SMP_VERSION
+
+ IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
+else:
+ IS_SAGEMAKER_MP_POST_1_10 = False
if TYPE_CHECKING:
@@ -217,7 +244,7 @@ class Trainer:
default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
[`DataCollatorWithPadding`] otherwise.
train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
- The dataset to use for training. If it is an `datasets.Dataset`, columns not accepted by the
+ The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed.
Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
@@ -226,7 +253,7 @@ class Trainer:
manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
sets the seed of the RNGs used.
eval_dataset (`torch.utils.data.Dataset`, *optional*):
- The dataset to use for evaluation. If it is an `datasets.Dataset`, columns not accepted by the
+ The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed.
tokenizer ([`PreTrainedTokenizerBase`], *optional*):
The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the
@@ -297,7 +324,7 @@ def __init__(
args = TrainingArguments(output_dir=output_dir)
self.args = args
# Seed must be set before instantiating the model when using model
- set_seed(self.args.seed)
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
self.hp_name = None
self.deepspeed = None
self.is_in_train = False
@@ -322,12 +349,21 @@ def __init__(
else:
if model_init is not None:
warnings.warn(
- "`Trainer` requires either a `model` or `model_init` argument, but not both. "
- "`model_init` will overwrite your model when calling the `train` method. This will become a fatal error in the next release.",
+ "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
+ " overwrite your model when calling the `train` method. This will become a fatal error in the next"
+ " release.",
FutureWarning,
)
self.model_init = model_init
+ if model.__class__.__name__ in MODEL_MAPPING_NAMES:
+ raise ValueError(
+ f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
+ "computes hidden states and does not accept any labels. You should choose a model with a head "
+ "suitable for your task like any of the `AutoModelForXxx` listed at "
+ "https://huggingface.co/docs/transformers/model_doc/auto."
+ )
+
if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
self.is_model_parallel = True
else:
@@ -340,6 +376,10 @@ def __init__(
raise ValueError(
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
)
+ if len(args.fsdp) > 0:
+ raise ValueError(
+ "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
+ )
if args.local_rank == -1:
raise ValueError("Using sharded DDP only works in distributed training.")
@@ -357,6 +397,31 @@ def __init__(
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
+ self.fsdp = None
+ if len(args.fsdp) > 0:
+ if args.deepspeed:
+ raise ValueError(
+ "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
+ )
+ if args.local_rank == -1:
+ raise ValueError("Using fsdp only works in distributed training.")
+
+ # dep_version_check("torch>=1.12.0")
+ # Would have to update setup.py with torch>=1.12.0
+ # which isn't ideal, given that it will force people not using FSDP to also use torch>=1.12.0
+ # below is the current alternative.
+ if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
+ raise ValueError("FSDP requires PyTorch >= 1.12.0")
+
+ from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
+
+ if FSDPOption.FULL_SHARD in args.fsdp:
+ self.fsdp = ShardingStrategy.FULL_SHARD
+ elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
+ self.fsdp = ShardingStrategy.SHARD_GRAD_OP
+ elif FSDPOption.NO_SHARD in args.fsdp:
+ self.fsdp = ShardingStrategy.NO_SHARD
+
# one place to sort out whether to place the model on device or not
# postpone switching model to cuda when:
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
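A sketch of how the new FSDP branch above is exercised from the user side, assuming the `fsdp` training argument introduced by this PR, a distributed launch, and torch >= 1.12 (illustrative only, not part of the diff):

    from transformers import TrainingArguments

    # "full_shard" resolves to ShardingStrategy.FULL_SHARD in the branch above;
    # "shard_grad_op" and "no_shard" map to the remaining strategies
    args = TrainingArguments(output_dir="out", fsdp="full_shard")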
@@ -364,12 +429,14 @@ def __init__(
# and we only use deepspeed for training at the moment
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
# 4. Sharded DDP - same as MP
+ # 5. FSDP - same as MP
self.place_model_on_device = args.place_model_on_device
if (
self.is_model_parallel
or args.deepspeed
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
+ or (self.fsdp is not None)
):
self.place_model_on_device = False
@@ -398,11 +465,11 @@ def __init__(
"Passing a `model_init` is incompatible with providing the `optimizers` argument. "
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
- if (self.sharded_ddp is not None or args.deepspeed) and (
+ if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
self.optimizer is not None or self.lr_scheduler is not None
):
raise RuntimeError(
- "Passing `optimizers` is not allowed if Fairscale or Deepspeed is enabled."
+ "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
@@ -447,44 +514,96 @@ def __init__(
# Mixed precision setup
self.use_apex = False
- self.use_amp = False
+ self.use_cuda_amp = False
+ self.use_cpu_amp = False
+
+ # Mixed precision setup for SageMaker Model Parallel
+ if is_sagemaker_mp_enabled():
+ # BF16 + model parallelism in SageMaker: currently not supported, raise an error
+ if args.bf16:
+ raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
+
+ if IS_SAGEMAKER_MP_POST_1_10:
+ # When there's mismatch between SMP config and trainer argument, use SMP config as truth
+ if args.fp16 != smp.state.cfg.fp16:
+ logger.warning(
+ f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
+ f"but FP16 provided in trainer argument is {args.fp16},"
+ f"setting to {smp.state.cfg.fp16}"
+ )
+ args.fp16 = smp.state.cfg.fp16
+ else:
+ # smp < 1.10 does not support fp16 in trainer.
+ if hasattr(smp.state.cfg, "fp16"):
+ logger.warning(
+ f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+ "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
+ )
if args.fp16 or args.bf16:
if args.half_precision_backend == "auto":
- if _is_native_amp_available:
- args.half_precision_backend = "amp"
+ if args.device == torch.device("cpu"):
+ if args.fp16:
+ raise ValueError("Tried to use `fp16` but it is not supported on cpu")
+ elif _is_native_cpu_amp_available:
+ args.half_precision_backend = "cpu_amp"
+ else:
+ raise ValueError("Tried to use cpu amp but native cpu amp is not available")
else:
- if args.bf16:
+ if _is_native_cuda_amp_available:
+ args.half_precision_backend = "cuda_amp"
+ elif args.bf16:
raise ValueError("Tried to use `bf16` but native amp is not available")
else:
args.half_precision_backend = "apex"
+
logger.info(f"Using {args.half_precision_backend} half precision backend")
self.do_grad_scaling = False
- if (args.fp16 or args.bf16) and not args.deepspeed: # deepspeed manages its own half precision
- if args.half_precision_backend == "amp":
- self.use_amp = True
+ if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
+ # deepspeed and SageMaker Model Parallel manage their own half precision
+ if args.half_precision_backend == "cuda_amp":
+ self.use_cuda_amp = True
self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
self.do_grad_scaling = True
- if is_sagemaker_mp_enabled():
- self.scaler = smp.amp.GradScaler()
- elif self.sharded_ddp is not None:
+ if self.sharded_ddp is not None:
self.scaler = ShardedGradScaler()
+ elif self.fsdp is not None:
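+ # PyTorch FSDP ships its own sharded grad scaler; loss scaling is only needed for fp16.
+ # With bf16 the else branch below disables grad scaling and cuda amp (the dtype is handled
+ # by the FSDP MixedPrecision policy set up in `_wrap_model`).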
+ if self.amp_dtype == torch.float16:
+ from torch.distributed.fsdp.sharded_grad_scaler import (
+ ShardedGradScaler as FSDPShardedGradScaler,
+ )
+
+ self.scaler = FSDPShardedGradScaler()
+ else:
+ self.do_grad_scaling = False
+ self.use_cuda_amp = False
+ self.amp_dtype = None
+
elif is_torch_tpu_available():
from torch_xla.amp import GradScaler
self.scaler = GradScaler()
else:
self.scaler = torch.cuda.amp.GradScaler()
+ elif args.half_precision_backend == "cpu_amp":
+ self.use_cpu_amp = True
+ self.amp_dtype = torch.bfloat16
else:
if not is_apex_available():
raise ImportError(
- "Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex."
+ "Using FP16 with APEX but APEX is not installed, please refer to"
+ " https://www.github.com/nvidia/apex."
)
self.use_apex = True
# FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
- if is_sagemaker_mp_enabled() and self.use_amp and args.max_grad_norm is not None and args.max_grad_norm > 0:
+ if (
+ is_sagemaker_mp_enabled()
+ and self.use_cuda_amp
+ and args.max_grad_norm is not None
+ and args.max_grad_norm > 0
+ ):
raise ValueError(
"SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
"along 'max_grad_norm': 0 in your hyperparameters."
@@ -511,9 +630,41 @@ def __init__(
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
+ # Internal variables to keep track of the original batch size
+ self._train_batch_size = args.train_batch_size
+
# very last
self._memory_tracker.stop_and_update_metrics()
+ # torchdynamo
+ if args.torchdynamo:
+ if not is_torchdynamo_available():
+ raise RuntimeError("Torchdynamo is not installed.")
+ import torchdynamo
+ from torchdynamo.optimizations import backends
+ from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
+
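+ # Build the torchdynamo optimization context manager once; compute_loss_context_manager()
+ # later enters it around the forward/loss computation in training_step and prediction_step.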
+ def get_ctx():
+ # Normal
+ if args.torchdynamo == "eager":
+ return torchdynamo.optimize("eager")
+ elif args.torchdynamo == "nvfuser":
+ return torchdynamo.optimize(aot_autograd_speedup_strategy)
+ # TensorRT
+ if args.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
+ if not is_torch_tensorrt_fx_available():
+ raise RuntimeError("Torch-TensorRT FX path is not installed.")
+ if args.torchdynamo == "fx2trt-fp16":
+ return torchdynamo.optimize(backends.fx2trt_compiler_fp16)
+ elif args.torchdynamo == "fx2trt":
+ return torchdynamo.optimize(backends.fx2trt_compiler)
+ else:
+ raise RuntimeError(f"Torchdynamo backend {args.torchdynamo} is not supported.")
+
+ self.ctx_manager_torchdynamo = get_ctx()
+ else:
+ self.ctx_manager_torchdynamo = contextlib.nullcontext()
+
def add_callback(self, callback):
"""
Add a callback to the current list of [`~transformer.TrainerCallback`].
@@ -558,27 +709,31 @@ def _move_model_to_device(self, model, device):
if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
model.tie_weights()
- def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
- if not self.args.remove_unused_columns:
- return dataset
+ def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
- self._signature_columns += ["label", "label_ids"]
+ self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
- ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
+ def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
+ if not self.args.remove_unused_columns:
+ return dataset
+ self._set_signature_columns_if_needed()
+ signature_columns = self._signature_columns
+
+ ignored_columns = list(set(dataset.column_names) - set(signature_columns))
if len(ignored_columns) > 0:
- dset_description = "" if description is None else f"in the {description} set "
+ dset_description = "" if description is None else f"in the {description} set"
logger.info(
f"The following columns {dset_description} don't have a corresponding argument in "
f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
- f" you can safely ignore this message."
+ " you can safely ignore this message."
)
- columns = [k for k in self._signature_columns if k in dataset.column_names]
+ columns = [k for k in signature_columns if k in dataset.column_names]
if version.parse(datasets.__version__) < version.parse("1.4.0"):
dataset.set_format(
@@ -588,6 +743,24 @@ def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optio
else:
return dataset.remove_columns(ignored_columns)
+ def _get_collator_with_removed_columns(
+ self, data_collator: Callable, description: Optional[str] = None
+ ) -> Callable:
+ """Wrap the data collator in a callable removing unused columns."""
+ if not self.args.remove_unused_columns:
+ return data_collator
+ self._set_signature_columns_if_needed()
+ signature_columns = self._signature_columns
+
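+ # For datasets that are not `datasets.Dataset` instances, unused columns cannot be dropped
+ # ahead of time, so this wrapper filters out keys `model.forward` does not accept at collation time.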
+ remove_columns_collator = RemoveColumnsCollator(
+ data_collator=data_collator,
+ signature_columns=signature_columns,
+ logger=logger,
+ description=description,
+ model_name=self.model.__class__.__name__,
+ )
+ return remove_columns_collator
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None
@@ -674,14 +847,17 @@ def get_train_dataloader(self) -> DataLoader:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
+ data_collator = self.data_collator
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
if isinstance(train_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1:
train_dataset = IterableDatasetShard(
train_dataset,
- batch_size=self.args.train_batch_size,
+ batch_size=self._train_batch_size,
drop_last=self.args.dataloader_drop_last,
num_processes=self.args.world_size,
process_index=self.args.process_index,
@@ -690,7 +866,7 @@ def get_train_dataloader(self) -> DataLoader:
return DataLoader(
train_dataset,
batch_size=self.args.per_device_train_batch_size,
- collate_fn=self.data_collator,
+ collate_fn=data_collator,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
)
@@ -699,12 +875,13 @@ def get_train_dataloader(self) -> DataLoader:
return DataLoader(
train_dataset,
- batch_size=self.args.train_batch_size,
+ batch_size=self._train_batch_size,
sampler=train_sampler,
- collate_fn=self.data_collator,
+ collate_fn=data_collator,
drop_last=self.args.dataloader_drop_last,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
+ worker_init_fn=seed_worker,
)
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
@@ -744,15 +921,18 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
Args:
eval_dataset (`torch.utils.data.Dataset`, *optional*):
- If provided, will override `self.eval_dataset`. If it is an `datasets.Dataset`, columns not accepted by
- the `model.forward()` method are automatically removed. It must implement `__len__`.
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
"""
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+ data_collator = self.data_collator
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1:
@@ -766,7 +946,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
return DataLoader(
eval_dataset,
batch_size=self.args.eval_batch_size,
- collate_fn=self.data_collator,
+ collate_fn=data_collator,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
)
@@ -777,7 +957,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
eval_dataset,
sampler=eval_sampler,
batch_size=self.args.eval_batch_size,
- collate_fn=self.data_collator,
+ collate_fn=data_collator,
drop_last=self.args.dataloader_drop_last,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
@@ -791,11 +971,15 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
Args:
test_dataset (`torch.utils.data.Dataset`, *optional*):
- The test dataset to use. If it is an `datasets.Dataset`, columns not accepted by the `model.forward()`
- method are automatically removed. It must implement `__len__`.
+ The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
+ `model.forward()` method are automatically removed. It must implement `__len__`.
"""
+ data_collator = self.data_collator
+
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description="test")
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
if isinstance(test_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1:
@@ -809,7 +993,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
return DataLoader(
test_dataset,
batch_size=self.args.eval_batch_size,
- collate_fn=self.data_collator,
+ collate_fn=data_collator,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
)
@@ -821,8 +1005,9 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
test_dataset,
sampler=test_sampler,
batch_size=self.args.eval_batch_size,
- collate_fn=self.data_collator,
+ collate_fn=data_collator,
drop_last=self.args.dataloader_drop_last,
+ num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
)
@@ -835,7 +1020,12 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
`create_scheduler`) in a subclass.
"""
self.create_optimizer()
- self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+ if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
+ # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
+ optimizer = self.optimizer.optimizer
+ else:
+ optimizer = self.optimizer
+ self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
def create_optimizer(self):
"""
@@ -847,7 +1037,7 @@ def create_optimizer(self):
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None:
- decay_parameters = get_parameter_names(opt_model, [nn.LayerNorm])
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
{
@@ -937,6 +1127,10 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
optimizer_kwargs.update(adam_kwargs)
except ImportError:
raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
+ elif args.optim == OptimizerNames.SGD:
+ optimizer_cls = torch.optim.SGD
+ elif args.optim == OptimizerNames.ADAGRAD:
+ optimizer_cls = torch.optim.Adagrad
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
@@ -964,6 +1158,10 @@ def num_examples(self, dataloader: DataLoader) -> int:
dataloader.dataset does not exist or has no length, estimates as best it can
"""
try:
+ dataset = dataloader.dataset
+ # Special case for IterableDatasetShard, we need to dig deeper
+ if isinstance(dataset, IterableDatasetShard):
+ return len(dataloader.dataset.dataset)
return len(dataloader.dataset)
except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader
return len(dataloader) * self.args.per_device_train_batch_size
@@ -987,7 +1185,8 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
for key, value in params.items():
if not hasattr(self.args, key):
logger.warning(
- f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
+ f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
+ " `TrainingArguments`."
)
continue
old_attr = getattr(self.args, key, None)
@@ -1008,16 +1207,14 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
self.args.hf_deepspeed_config.trainer_config_process(self.args)
- def _report_to_hp_search(
- self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
- ):
+ def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
if self.hp_search_backend is None or trial is None:
return
self.objective = self.compute_objective(metrics.copy())
if self.hp_search_backend == HPSearchBackend.OPTUNA:
import optuna
- trial.report(self.objective, epoch)
+ trial.report(self.objective, step)
if trial.should_prune():
self.callback_handler.on_train_end(self.args, self.state, self.control)
raise optuna.TrialPruned()
@@ -1055,7 +1252,58 @@ def call_model_init(self, trial=None):
return model
- def _wrap_model(self, model, training=True):
+ def torch_jit_model_eval(self, model, dataloader, training=False):
+ if not training:
+ if dataloader is None:
+ logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
+ return model
+ jit_inputs = []
+ example_batch = next(iter(dataloader))
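+ # Trace with dummy tensors shaped like a real batch; strict=False lets the traced module return dicts.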
+ for key in example_batch:
+ example_tensor = torch.ones_like(example_batch[key])
+ jit_inputs.append(example_tensor)
+ jit_inputs = tuple(jit_inputs)
+ try:
+ jit_model = model.eval()
+ with ContextManagers([self.autocast_smart_context_manager(), torch.no_grad()]):
+ jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
+ jit_model = torch.jit.freeze(jit_model)
+ jit_model(**example_batch)
+ model = jit_model
+ except (RuntimeError, TypeError) as e:
+ logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
+
+ return model
+
+ def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
+ if not is_ipex_available():
+ raise ImportError(
+ "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
+ " to https://github.com/intel/intel-extension-for-pytorch."
+ )
+
+ import intel_extension_for_pytorch as ipex
+
+ if not training:
+ model.eval()
+ model = ipex.optimize(model, dtype=dtype, level="O1")
+ else:
+ if not model.training:
+ model.train()
+ model, self.optimizer = ipex.optimize(
+ model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
+ )
+
+ return model
+
+ def _wrap_model(self, model, training=True, dataloader=None):
+ if self.args.use_ipex:
+ dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
+ model = self.ipex_optimize_model(model, training, dtype=dtype)
+
+ if self.args.jit_mode_eval:
+ model = self.torch_jit_model_eval(model, dataloader, training)
+
if is_sagemaker_mp_enabled():
# Wrapping the base model twice in a DistributedModel will raise an error.
if isinstance(self.model_wrapped, smp.model.DistributedModel):
@@ -1101,7 +1349,55 @@ def _wrap_model(self, model, training=True):
reshard_after_forward=zero_3,
cpu_offload=cpu_offload,
).to(self.args.device)
+ # Distributed training using PyTorch FSDP
+ elif self.fsdp is not None:
+ # PyTorch FSDP!
+ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
+ from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+ from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
+
+ if FSDPOption.OFFLOAD in self.args.fsdp:
+ cpu_offload = CPUOffload(offload_params=True)
+ else:
+ cpu_offload = CPUOffload(offload_params=False)
+ auto_wrap_policy = None
+ if FSDPOption.AUTO_WRAP in self.args.fsdp:
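+ # If both options are set, fsdp_min_num_params takes precedence over fsdp_transformer_layer_cls_to_wrap.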
+ if self.args.fsdp_min_num_params > 0:
+ auto_wrap_policy = functools.partial(
+ size_based_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
+ )
+ elif self.args.fsdp_transformer_layer_cls_to_wrap is not None:
+ transformer_cls_to_wrap = get_module_class_from_name(
+ model, self.args.fsdp_transformer_layer_cls_to_wrap
+ )
+ if transformer_cls_to_wrap is None:
+ raise Exception("Could not find the transformer layer class to wrap in the model.")
+ auto_wrap_policy = functools.partial(
+ transformer_auto_wrap_policy,
+ # Transformer layer class to wrap
+ transformer_layer_cls={transformer_cls_to_wrap},
+ )
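+ # Build an FSDP MixedPrecision policy that uses the requested half-precision dtype for
+ # parameters, gradient reduction and buffers alike.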
+ mixed_precision_policy = None
+ dtype = None
+ if self.args.fp16:
+ dtype = torch.float16
+ elif self.args.bf16:
+ dtype = torch.bfloat16
+ if dtype is not None:
+ mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
+ if type(model) != FSDP:
+ # XXX: Breaking the self.model convention but I see no way around it for now.
+ self.model = model = FSDP(
+ model,
+ sharding_strategy=self.fsdp,
+ cpu_offload=cpu_offload,
+ auto_wrap_policy=auto_wrap_policy,
+ mixed_precision=mixed_precision_policy,
+ )
+ if FSDPOption.OFFLOAD not in self.args.fsdp:
+ model.to(self.args.device)
elif is_sagemaker_dp_enabled():
model = nn.parallel.DistributedDataParallel(
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
@@ -1182,7 +1478,7 @@ def train(
model_reloaded = False
if self.model_init is not None:
# Seed must be set before instantiating the model when using model_init.
- set_seed(args.seed)
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
self.model = self.call_model_init(trial)
model_reloaded = True
# Reinitializes optimizer and scheduler
@@ -1194,7 +1490,7 @@ def train(
if resume_from_checkpoint is None:
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
- if resume_from_checkpoint is not None:
+ if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled():
self._load_from_checkpoint(resume_from_checkpoint)
# If model was re-initialized, put it on the right device and update self.model_wrapped
@@ -1203,6 +1499,20 @@ def train(
self._move_model_to_device(self.model, args.device)
self.model_wrapped = self.model
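+ # When args.auto_find_batch_size is enabled, this wrapper retries _inner_training_loop with a
+ # smaller batch size after an out-of-memory error; otherwise it simply binds the starting batch size.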
+ inner_training_loop = find_executable_batch_size(
+ self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
+ )
+ return inner_training_loop(
+ args=args,
+ resume_from_checkpoint=resume_from_checkpoint,
+ trial=trial,
+ ignore_keys_for_eval=ignore_keys_for_eval,
+ )
+
+ def _inner_training_loop(
+ self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+ ):
+ self._train_batch_size = batch_size
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
@@ -1239,7 +1549,8 @@ def train(
num_train_samples = args.max_steps * total_train_batch_size
else:
raise ValueError(
- f"args.max_steps must be set to a positive value if dataloader does not have a length, was {args.max_steps}"
+ "args.max_steps must be set to a positive value if dataloader does not have a length, was"
+ f" {args.max_steps}"
)
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
@@ -1247,13 +1558,17 @@ def train(
# nn.DataParallel(model) replicates the model, creating new variables and module
# references registered here no longer work on other gpus, breaking the module
raise ValueError(
- "Currently --debug underflow_overflow is not supported under DP. Please use DDP (torch.distributed.launch)."
+ "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
+ " (torch.distributed.launch)."
)
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
delay_optimizer_creation = (
- self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE or is_sagemaker_mp_enabled()
+ self.sharded_ddp is not None
+ and self.sharded_ddp != ShardedDDPOption.SIMPLE
+ or is_sagemaker_mp_enabled()
+ or self.fsdp is not None
)
if args.deepspeed:
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
@@ -1276,6 +1591,9 @@ def train(
model = self._wrap_model(self.model_wrapped)
+ if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
+ self._load_from_checkpoint(resume_from_checkpoint, model)
+
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
@@ -1363,7 +1681,7 @@ def train(
is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
train_dataloader.sampler, RandomSampler
)
- if version.parse(torch.__version__) < version.parse("1.11") or not is_random_sampler:
+ if is_torch_less_than_1_11 or not is_random_sampler:
# We just need to begin an iteration to create the randomization of the sampler.
# That was before PyTorch 1.11 however...
for _ in train_dataloader:
@@ -1465,7 +1783,9 @@ def train(
# AMP: gradients need unscaling
self.scaler.unscale_(self.optimizer)
- if hasattr(self.optimizer, "clip_grad_norm"):
+ if is_sagemaker_mp_enabled() and args.fp16:
+ self.optimizer.clip_master_grads(args.max_grad_norm)
+ elif hasattr(self.optimizer, "clip_grad_norm"):
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
self.optimizer.clip_grad_norm(args.max_grad_norm)
elif hasattr(model, "clip_grad_norm_"):
@@ -1513,7 +1833,7 @@ def train(
break
if step < 0:
logger.warning(
- f"There seems to be not a single sample in your epoch_iterator, stopping training at step"
+ "There seems to be not a single sample in your epoch_iterator, stopping training at step"
f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
f" num_steps ({max_steps}) higher than the number of available samples."
)
@@ -1545,6 +1865,8 @@ def train(
xm.rendezvous("load_best_model_at_end")
elif args.local_rank != -1:
dist.barrier()
+ elif is_sagemaker_mp_enabled():
+ smp.barrier()
self._load_best_model()
@@ -1574,13 +1896,17 @@ def train(
return TrainOutput(self.state.global_step, train_loss, metrics)
- def _load_from_checkpoint(self, resume_from_checkpoint):
+ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
+
+ if model is None:
+ model = self.model
+
if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
):
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
- logger.info(f"Loading model from {resume_from_checkpoint}).")
+ logger.info(f"Loading model from {resume_from_checkpoint}.")
if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
@@ -1596,45 +1922,93 @@ def _load_from_checkpoint(self, resume_from_checkpoint):
# will be resumed in deepspeed_init
pass
elif os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
- # We load the model state dict on the CPU to avoid an OOM error.
- state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
# If the model is on the GPU, it still works!
- load_result = self.model.load_state_dict(state_dict, strict=False)
- self._issue_warnings_after_load(load_result)
-
- # release memory
- del state_dict
+ if is_sagemaker_mp_enabled():
+ if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
+ # If the 'user_content.pt' file exists, load with the new smp api.
+ # Checkpoint must have been saved with the new smp api.
+ smp.resume_from_checkpoint(
+ path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
+ )
+ else:
+ # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+ # Checkpoint must have been saved with the old smp api.
+ if hasattr(self.args, "fp16") and self.args.fp16 is True:
+ logger.warning(
+ "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
+ )
+ state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+ # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
+ state_dict["_smp_is_partial"] = False
+ load_result = model.load_state_dict(state_dict, strict=True)
+ # release memory
+ del state_dict
+ else:
+ # We load the model state dict on the CPU to avoid an OOM error.
+ state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+ load_result = model.load_state_dict(state_dict, strict=False)
+ # release memory
+ del state_dict
+ self._issue_warnings_after_load(load_result)
else:
# We load the sharded checkpoint
- load_result = load_sharded_checkpoint(self.model, resume_from_checkpoint, strict=False)
- self._issue_warnings_after_load(load_result)
+ load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled())
+ if not is_sagemaker_mp_enabled():
+ self._issue_warnings_after_load(load_result)
def _load_best_model(self):
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
-
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+ model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if os.path.exists(best_model_path):
if self.deepspeed:
+
+ if self.model_wrapped is not None:
+ # this removes the pre-hooks from the previous engine
+ self.model_wrapped.destroy()
+ self.model_wrapped = None
+
# temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
- deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
+ deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
+ self,
+ num_training_steps=self.args.max_steps,
+ resume_from_checkpoint=self.state.best_model_checkpoint,
+ )
self.model = deepspeed_engine.module
self.model_wrapped = deepspeed_engine
self.deepspeed = deepspeed_engine
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
- self.deepspeed.load_checkpoint(
- self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
- )
else:
- # We load the model state dict on the CPU to avoid an OOM error.
- state_dict = torch.load(best_model_path, map_location="cpu")
- # If the model is on the GPU, it still works!
- load_result = self.model.load_state_dict(state_dict, strict=False)
+ if is_sagemaker_mp_enabled():
+ if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
+ # If the 'user_content.pt' file exists, load with the new smp api.
+ # Checkpoint must have been saved with the new smp api.
+ smp.resume_from_checkpoint(
+ path=self.state.best_model_checkpoint,
+ tag=WEIGHTS_NAME,
+ partial=False,
+ load_optimizer=False,
+ )
+ else:
+ # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+ # Checkpoint must have been saved with the old smp api.
+ state_dict = torch.load(best_model_path, map_location="cpu")
+ state_dict["_smp_is_partial"] = False
+ load_result = model.load_state_dict(state_dict, strict=True)
+ else:
+ # We load the model state dict on the CPU to avoid an OOM error.
+ state_dict = torch.load(best_model_path, map_location="cpu")
+ # If the model is on the GPU, it still works!
+ load_result = model.load_state_dict(state_dict, strict=False)
+ if not is_sagemaker_mp_enabled():
+ self._issue_warnings_after_load(load_result)
+ elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
+ load_result = load_sharded_checkpoint(
+ model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
+ )
+ if not is_sagemaker_mp_enabled():
self._issue_warnings_after_load(load_result)
- elif os.path.exists(best_model_path, os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
- # Best model is a sharded checkpoint
- load_result = load_sharded_checkpoint(self.model, self.state.best_model_checkpoint, strict=False)
- self._issue_warnings_after_load(load_result)
else:
logger.warning(
f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
@@ -1680,7 +2054,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
metrics = None
if self.control.should_evaluate:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
- self._report_to_hp_search(trial, epoch, metrics)
+ self._report_to_hp_search(trial, self.state.global_step, metrics)
if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
@@ -1691,12 +2065,12 @@ def _load_rng_state(self, checkpoint):
if checkpoint is None:
return
- local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
- if local_rank != -1:
- rng_file = os.path.join(checkpoint, f"rng_state_{local_rank}.pth")
- if not os.path.isfile(os.path.join(checkpoint, rng_file)):
+ if self.args.world_size > 1:
+ process_index = self.args.process_index
+ rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
+ if not os.path.isfile(rng_file):
logger.info(
- f"Didn't find an RNG file for process {local_rank}, if you are resuming a training that "
+ f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
)
return
@@ -1772,17 +2146,21 @@ def _save_checkpoint(self, model, trial, metrics=None):
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
elif is_sagemaker_mp_enabled():
- if smp.rdp_rank() == 0:
- # Consolidate the state dict on all processed of rdp_rank 0
- opt_state_dict = self.optimizer.state_dict()
- # Save it and the scheduler on the main process
- if self.args.should_save:
- torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
- with warnings.catch_warnings(record=True) as caught_warnings:
- torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
- reissue_pt_warnings(caught_warnings)
- if self.do_grad_scaling:
- torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
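+ # Collect the local (non-gathered) optimizer state on each rank; it is written as a partial
+ # checkpoint by rdp rank 0, or by every rank when optimizer state sharding is enabled.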
+ opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
+ smp.barrier()
+ if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
+ smp.save(
+ opt_state_dict,
+ os.path.join(output_dir, OPTIMIZER_NAME),
+ partial=True,
+ v3=smp.state.cfg.shard_optimizer_state,
+ )
+ if self.args.should_save:
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+ reissue_pt_warnings(caught_warnings)
+ if self.do_grad_scaling:
+ torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
elif self.args.should_save and not self.deepspeed:
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
@@ -1831,11 +2209,11 @@ def _save_checkpoint(self, model, trial, metrics=None):
# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
# not yet exist.
os.makedirs(output_dir, exist_ok=True)
- local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
- if local_rank == -1:
+
+ if self.args.world_size <= 1:
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
else:
- torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
+ torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)
@@ -1853,9 +2231,12 @@ def _load_optimizer_and_scheduler(self, checkpoint):
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
return
- if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
- os.path.join(checkpoint, SCHEDULER_NAME)
- ):
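+ # SageMaker MP saves partial optimizer checkpoints as "<OPTIMIZER_NAME>_*" shards,
+ # so their existence is checked with a glob instead of a single file path.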
+ checkpoint_file_exists = (
+ glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
+ if is_sagemaker_mp_enabled()
+ else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+ )
+ if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
# Load in optimizer and scheduler states
if is_torch_tpu_available():
# On TPU we have to take some extra precautions to properly load the states on the right device.
@@ -1871,9 +2252,27 @@ def _load_optimizer_and_scheduler(self, checkpoint):
self.lr_scheduler.load_state_dict(lr_scheduler_state)
else:
map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
- self.optimizer.load_state_dict(
- torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
- )
+ if is_sagemaker_mp_enabled():
+ if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
+ # Optimizer checkpoint was saved with smp >= 1.10
+ def opt_load_hook(mod, opt):
+ opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+ else:
+ # Optimizer checkpoint was saved with smp < 1.10
+ def opt_load_hook(mod, opt):
+ if IS_SAGEMAKER_MP_POST_1_10:
+ opt.load_state_dict(
+ smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
+ )
+ else:
+ opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
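+ # The optimizer state is loaded through a post-step hook on the SMP-wrapped model
+ # rather than eagerly here.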
+ self.model_wrapped.register_post_step_hook(opt_load_hook)
+ else:
+ self.optimizer.load_state_dict(
+ torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+ )
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
@@ -2024,16 +2423,37 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s
return inputs
+ def compute_loss_context_manager(self):
+ """
+ A helper wrapper to group together context managers.
+ """
+ return ContextManagers(
+ [
+ self.torchdynamo_smart_context_manager(),
+ self.autocast_smart_context_manager(),
+ ]
+ )
+
+ def torchdynamo_smart_context_manager(self):
+ """
+ A helper wrapper that creates an appropriate context manager for `torchdynamo`.
+ """
+ return self.ctx_manager_torchdynamo
+
def autocast_smart_context_manager(self):
"""
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
arguments, depending on the situation.
"""
- if self.use_amp:
- if version.parse(torch.__version__) >= version.parse("1.10"):
- ctx_manager = autocast(dtype=self.amp_dtype)
+ if self.use_cuda_amp or self.use_cpu_amp:
+ if is_torch_greater_or_equal_than_1_10:
+ ctx_manager = (
+ torch.cpu.amp.autocast(dtype=self.amp_dtype)
+ if self.use_cpu_amp
+ else torch.cuda.amp.autocast(dtype=self.amp_dtype)
+ )
else:
- ctx_manager = autocast()
+ ctx_manager = torch.cuda.amp.autocast()
else:
ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()
@@ -2061,11 +2481,10 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled():
- scaler = self.scaler if self.do_grad_scaling else None
- loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
+ loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)
- with self.autocast_smart_context_manager():
+ with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1:
@@ -2105,8 +2524,16 @@ def compute_loss(self, model, inputs, return_outputs=False):
self._past = outputs[self.args.past_index]
if labels is not None:
- loss = self.label_smoother(outputs, labels)
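+ # When a label smoother is used, labels were popped before the forward pass, so causal LM
+ # logits come back unshifted and the smoother has to shift the labels itself.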
+ if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
+ loss = self.label_smoother(outputs, labels, shift_labels=True)
+ else:
+ loss = self.label_smoother(outputs, labels)
else:
+ if isinstance(outputs, dict) and "loss" not in outputs:
+ raise ValueError(
+ "The model did not return a loss from the inputs, only the following keys: "
+ f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
+ )
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
@@ -2145,11 +2572,17 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
self._save_tpu(output_dir)
elif is_sagemaker_mp_enabled():
# Calling the state_dict needs to be done on the wrapped model and on all processes.
+ os.makedirs(output_dir, exist_ok=True)
state_dict = self.model_wrapped.state_dict()
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
+ if IS_SAGEMAKER_MP_POST_1_10:
+ # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
+ Path(os.path.join(output_dir, "user_content.pt")).touch()
elif (
- ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
+ ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
+ or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
+ or self.fsdp is not None
):
state_dict = self.model.state_dict()
@@ -2177,8 +2610,9 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
# This must be called on all ranks
if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME):
logger.warning(
- "deepspeed.save_16bit_model didn't save the model, since stage3_gather_16bit_weights_on_model_save=false. "
- "Saving the full checkpoint instead, use zero_to_fp32.py to recover weights"
+ "deepspeed.save_16bit_model didn't save the model, since"
+ " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
+ " zero_to_fp32.py to recover weights"
)
self.deepspeed.save_checkpoint(output_dir)
@@ -2318,8 +2752,8 @@ def evaluate(
Args:
eval_dataset (`Dataset`, *optional*):
- Pass a dataset if you wish to override `self.eval_dataset`. If it is an `datasets.Dataset`, columns not
- accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
+ Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
+ not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
method.
ignore_keys (`Lst[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
@@ -2426,6 +2860,7 @@ def predict(
)
)
+ self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
self._memory_tracker.stop_and_update_metrics(output.metrics)
return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
@@ -2459,7 +2894,7 @@ def evaluation_loop(
self.model_wrapped = deepspeed_engine
self.deepspeed = deepspeed_engine
- model = self._wrap_model(self.model, training=False)
+ model = self._wrap_model(self.model, training=False, dataloader=dataloader)
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
@@ -2517,7 +2952,7 @@ def evaluation_loop(
# Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
- inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
+ inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
if is_torch_tpu_available():
xm.mark_step()
@@ -2756,7 +3191,7 @@ def prediction_step(
logits = smp_nested_concat(logits_mb)
else:
if has_labels:
- with self.autocast_smart_context_manager():
+ with self.compute_loss_context_manager():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()
@@ -2766,7 +3201,7 @@ def prediction_step(
logits = outputs[1:]
else:
loss = None
- with self.autocast_smart_context_manager():
+ with self.compute_loss_context_manager():
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
@@ -3020,7 +3455,7 @@ def prediction_loop(
deepspeed_engine.optimizer.optimizer = None
deepspeed_engine.lr_scheduler = None
- model = self._wrap_model(self.model, training=False)
+ model = self._wrap_model(self.model, training=False, dataloader=dataloader)
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
@@ -3065,7 +3500,7 @@ def prediction_loop(
for step, inputs in enumerate(dataloader):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
- inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
+ inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
if loss is not None:
losses = loss.repeat(batch_size)
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 92abe1ed5063..8749e5f3f574 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -262,6 +262,12 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra
"""
pass
+ def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs):
+ """
+ Event called after a successful prediction.
+ """
+ pass
+
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called after a checkpoint save.
@@ -372,6 +378,9 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra
control.should_evaluate = False
return self.call_event("on_evaluate", args, state, control, metrics=metrics)
+ def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
+ return self.call_event("on_predict", args, state, control, metrics=metrics)
+
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
control.should_save = False
return self.call_event("on_save", args, state, control)
@@ -484,6 +493,12 @@ def on_evaluate(self, args, state, control, **kwargs):
self.prediction_bar.close()
self.prediction_bar = None
+ def on_predict(self, args, state, control, **kwargs):
+ if state.is_local_process_zero:
+ if self.prediction_bar is not None:
+ self.prediction_bar.close()
+ self.prediction_bar = None
+
def on_log(self, args, state, control, logs=None, **kwargs):
if state.is_local_process_zero and self.training_bar is not None:
_ = logs.pop("total_flos", None)
@@ -556,7 +571,8 @@ def on_evaluate(self, args, state, control, metrics, **kwargs):
if metric_value is None:
logger.warning(
- f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping is disabled"
+ f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping"
+ " is disabled"
)
return
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index ac83826e40ca..e1ad471b07a9 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -43,7 +43,7 @@
if is_training_run_on_sagemaker():
logging.add_handler(StreamHandler(sys.stdout))
-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
@@ -55,8 +55,22 @@
logger = logging.get_logger(__name__)
+def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
+ if isinstance(tensor_or_array, torch.Tensor):
+ if hasattr(torch, "atleast_1d"):
+ tensor_or_array = torch.atleast_1d(tensor_or_array)
+ elif tensor_or_array.ndim < 1:
+ tensor_or_array = tensor_or_array[None]
+ else:
+ tensor_or_array = np.atleast_1d(tensor_or_array)
+ return tensor_or_array
+
+
def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
"""Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
+ tensor1 = atleast_1d(tensor1)
+ tensor2 = atleast_1d(tensor2)
+
if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
return torch.cat((tensor1, tensor2), dim=0)
@@ -72,6 +86,9 @@ def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
"""Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
+ array1 = atleast_1d(array1)
+ array2 = atleast_1d(array2)
+
if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
return np.concatenate((array1, array2), axis=0)
@@ -149,8 +166,7 @@ def nested_xla_mesh_reduce(tensors, name):
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
- if tensors.ndim == 0:
- tensors = tensors[None]
+ tensors = atleast_1d(tensors)
return xm.mesh_reduce(name, tensors, torch.cat)
else:
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
@@ -160,8 +176,7 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) ->
try:
if isinstance(tensor, (tuple, list)):
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
- if len(tensor.shape) <= 0:
- tensor = tensor[None]
+ tensor = atleast_1d(tensor)
output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
@@ -451,8 +466,12 @@ class LabelSmoother:
epsilon: float = 0.1
ignore_index: int = -100
- def __call__(self, model_output, labels):
+ def __call__(self, model_output, labels, shift_labels=False):
logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
+ if shift_labels:
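+ # For causal LMs the logits at position i predict token i+1, so drop the last logit and the first label to align them.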
+ logits = logits[..., :-1, :].contiguous()
+ labels = labels[..., 1:].contiguous()
+
log_probs = -nn.functional.log_softmax(logits, dim=-1)
if labels.dim() == log_probs.dim() - 1:
labels = labels.unsqueeze(-1)
@@ -539,6 +558,12 @@ def __init__(
f"'{model_input_name}' key."
)
lengths = [len(feature[model_input_name]) for feature in dataset]
+ elif isinstance(lengths, torch.Tensor):
+ logger.info(
+ "If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]..."
+ )
+ lengths = lengths.tolist()
+
self.lengths = lengths
self.generator = generator
@@ -595,6 +620,13 @@ def __init__(
f"'{model_input_name}' key."
)
lengths = [len(feature[model_input_name]) for feature in dataset]
+ elif isinstance(lengths, torch.Tensor):
+ logger.info(
+ "If lengths is a torch.Tensor, DistributedLengthGroupedSampler will be slow. Converting lengths to"
+ " List[int]..."
+ )
+ lengths = lengths.tolist()
+
self.lengths = lengths
# If the dataset length is evenly divisible by # of replicas, then there
@@ -803,7 +835,7 @@ def _get_learning_rate(self):
last_lr = (
# backward compatibility for pytorch schedulers
self.lr_scheduler.get_last_lr()[0]
- if version.parse(torch.__version__) >= version.parse("1.4")
+ if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.4")
else self.lr_scheduler.get_lr()[0]
)
return last_lr
@@ -1001,19 +1033,34 @@ def get_parameter_names(model, forbidden_layer_types):
return result
+def get_module_class_from_name(module, name):
+ """
+ Gets a class from a module by its name.
+
+ Args:
+ module (`torch.nn.Module`): The module to get the class from.
+ name (`str`): The name of the class.
+ """
+ modules_children = list(module.children())
+ if module.__class__.__name__ == name:
+ return module.__class__
+ elif len(modules_children) == 0:
+ return
+ else:
+ for child_module in modules_children:
+ module_class = get_module_class_from_name(child_module, name)
+ if module_class is not None:
+ return module_class
+
+
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
@smp.step()
- def smp_forward_backward(model, inputs, gradient_accumulation_steps=1, scaler=None):
- with torch.cuda.amp.autocast(enabled=(scaler is not None)):
- outputs = model(**inputs)
-
+ def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
+ outputs = model(**inputs)
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
loss /= gradient_accumulation_steps
- if scaler is not None:
- loss = scaler.scale(loss).squeeze()
-
model.backward(loss)
return loss
@@ -1031,7 +1078,7 @@ def smp_gather(tensor):
f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
)
all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
- all_tensors = [t if len(t.shape) > 0 else t[None] for t in all_tensors]
+ all_tensors = [atleast_1d(t) for t in all_tensors]
return torch.cat([t.cpu() for t in all_tensors], dim=0)
def smp_nested_concat(tensor):
diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py
index 5513b58bef94..02ce3d393b9e 100644
--- a/src/transformers/trainer_seq2seq.py
+++ b/src/transformers/trainer_seq2seq.py
@@ -33,8 +33,7 @@ def evaluate(
eval_dataset: Optional[Dataset] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
- max_length: Optional[int] = None,
- num_beams: Optional[int] = None,
+ **gen_kwargs
) -> Dict[str, float]:
"""
Run evaluation and returns metrics.
@@ -46,8 +45,8 @@ def evaluate(
Args:
eval_dataset (`Dataset`, *optional*):
- Pass a dataset if you wish to override `self.eval_dataset`. If it is an `datasets.Dataset`, columns not
- accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
+ Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
+ not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
method.
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
@@ -60,13 +59,23 @@ def evaluate(
num_beams (`int`, *optional*):
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
beam search.
+ gen_kwargs:
+ Additional `generate` specific kwargs.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
dictionary also contains the epoch number which comes from the training state.
"""
- self._max_length = max_length if max_length is not None else self.args.generation_max_length
- self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
+
+ gen_kwargs = gen_kwargs.copy()
+ gen_kwargs["max_length"] = (
+ gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
+ )
+ gen_kwargs["num_beams"] = (
+ gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
+ )
+ self._gen_kwargs = gen_kwargs
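+ # prediction_step later copies self._gen_kwargs and fills in defaults for max_length,
+ # num_beams and synced_gpus before calling generate.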
+
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def predict(
@@ -74,8 +83,7 @@ def predict(
test_dataset: Dataset,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "test",
- max_length: Optional[int] = None,
- num_beams: Optional[int] = None,
+ **gen_kwargs
) -> PredictionOutput:
"""
Run prediction and returns predictions and potential metrics.
@@ -85,7 +93,7 @@ def predict(
Args:
test_dataset (`Dataset`):
- Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the
+ Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed. Has to implement the method `__len__`
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
@@ -98,6 +106,8 @@ def predict(
num_beams (`int`, *optional*):
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
beam search.
+ gen_kwargs:
+ Additional `generate` specific kwargs.
@@ -114,8 +124,16 @@ def predict(
- metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
labels).
"""
- self._max_length = max_length if max_length is not None else self.args.generation_max_length
- self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
+
+ gen_kwargs = gen_kwargs.copy()
+ gen_kwargs["max_length"] = (
+ gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
+ )
+ gen_kwargs["num_beams"] = (
+ gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
+ )
+ self._gen_kwargs = gen_kwargs
+
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def prediction_step(
@@ -155,11 +173,17 @@ def prediction_step(
inputs = self._prepare_inputs(inputs)
# XXX: adapt synced_gpus for fairscale as well
- gen_kwargs = {
- "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
- "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
- "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
- }
+ gen_kwargs = self._gen_kwargs.copy()
+ gen_kwargs["max_length"] = (
+ gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.model.config.max_length
+ )
+ gen_kwargs["num_beams"] = (
+ gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
+ )
+ default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
+ gen_kwargs["synced_gpus"] = (
+ gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
+ )
if "attention_mask" in inputs:
gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
@@ -183,7 +207,7 @@ def prediction_step(
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
with torch.no_grad():
- with self.autocast_smart_context_manager():
+ with self.compute_loss_context_manager():
outputs = model(**inputs)
if has_labels:
if self.label_smoother is not None:
diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py
index 71c2e691d2a7..737dd4deaf68 100644
--- a/src/transformers/trainer_tf.py
+++ b/src/transformers/trainer_tf.py
@@ -34,7 +34,14 @@
from .modeling_tf_utils import TFPreTrainedModel
from .optimization_tf import GradientAccumulator, create_optimizer
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, IntervalStrategy, PredictionOutput, set_seed
+from .trainer_utils import (
+ PREFIX_CHECKPOINT_DIR,
+ EvalPrediction,
+ IntervalStrategy,
+ PredictionOutput,
+ enable_full_determinism,
+ set_seed,
+)
from .training_args_tf import TFTrainingArguments
from .utils import logging
@@ -134,7 +141,7 @@ def __init__(
"see https://www.comet.ml/docs/python-sdk/huggingface/"
)
- set_seed(self.args.seed)
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
def get_train_tfdataset(self) -> tf.data.Dataset:
"""
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 4450bfde646e..579e5d1dc24c 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -25,7 +25,7 @@
import re
import threading
import time
-from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
import numpy as np
@@ -36,6 +36,7 @@
is_torch_available,
is_torch_cuda_available,
is_torch_tpu_available,
+ requires_backends,
)
@@ -46,6 +47,39 @@
import tensorflow as tf
+def seed_worker(_):
+ """
+ Helper function to set the worker seed during DataLoader initialization.
+ """
+ worker_seed = torch.initial_seed() % 2**32
+ set_seed(worker_seed)
+
+
+def enable_full_determinism(seed: int):
+ """
+ Helper function for reproducible behavior during distributed training. See
+ - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+ - https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism for tensorflow
+ """
+ # set seed first
+ set_seed(seed)
+
+ if is_torch_available():
+ # Enable PyTorch deterministic mode. This potentially requires either the environment
+ # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+ # depending on the CUDA version, so we set them both here
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+ torch.use_deterministic_algorithms(True)
+
+ # Enable CUDNN deterministic mode
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ if is_tf_available():
+ tf.config.experimental.enable_op_determinism()
+
+
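
A minimal sketch of how a `worker_init_fn` such as `seed_worker` is typically wired into a PyTorch `DataLoader` (assumes `torch` is installed; the dataset below is a toy placeholder):

import torch
from torch.utils.data import DataLoader, TensorDataset

toy_dataset = TensorDataset(torch.arange(10).float())  # placeholder data, illustration only
generator = torch.Generator().manual_seed(42)           # mirrors the trainer's data seed
loader = DataLoader(
    toy_dataset,
    batch_size=2,
    num_workers=2,
    worker_init_fn=seed_worker,  # defined above; reseeds random/numpy/torch in each worker
    generator=generator,
)
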
def set_seed(seed: int):
"""
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).
@@ -211,7 +245,7 @@ def default_hp_space_optuna(trial) -> Dict[str, float]:
def default_hp_space_ray(trial) -> Dict[str, float]:
from .integrations import is_ray_tune_available
- assert is_ray_tune_available(), "This function needs ray installed: `pip " "install ray[tune]`"
+ assert is_ray_tune_available(), "This function needs ray installed: `pip install ray[tune]`"
from ray import tune
return {
@@ -273,7 +307,7 @@ def is_main_process(local_rank):
Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
`local_rank`.
"""
- if is_torch_tpu_available():
+ if is_torch_tpu_available(check_device=True):
import torch_xla.core.xla_model as xm
return xm.get_ordinal() == 0
@@ -284,7 +318,7 @@ def total_processes_number(local_rank):
"""
Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
"""
- if is_torch_tpu_available():
+ if is_torch_tpu_available(check_device=True):
import torch_xla.core.xla_model as xm
return xm.xrt_world_size()
@@ -355,6 +389,7 @@ class TrainerMemoryTracker:
stages = {
"__init__": "init",
"train": "train",
+ "_inner_training_loop": "train",
"evaluate": "eval",
"predict": "test",
}
@@ -582,3 +617,81 @@ class ShardedDDPOption(ExplicitEnum):
ZERO_DP_3 = "zero_dp_3"
OFFLOAD = "offload"
AUTO_WRAP = "auto_wrap"
+
+
+def find_executable_batch_size(
+ function: callable = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
+):
+ """
+ A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
+ CUDNN, the batch size is cut in half and passed to `function`. `function` must take in a `batch_size` parameter as
+ its first argument.
+
+ Args:
+ function (`callable`, *optional*):
+ A function to wrap
+ starting_batch_size (`int`, *optional*):
+ The batch size to try and fit into memory
+ auto_find_batch_size (`bool`, *optional*):
+ If `False`, will just execute `function`
+ """
+ if function is None:
+ return functools.partial(
+ find_executable_batch_size,
+ starting_batch_size=starting_batch_size,
+ auto_find_batch_size=auto_find_batch_size,
+ )
+
+ if auto_find_batch_size:
+ requires_backends(find_executable_batch_size, "accelerate")
+ import accelerate.memory_utils as mem_utils
+
+ return mem_utils.find_executable_batch_size(function=function, starting_batch_size=starting_batch_size)
+
+ return functools.partial(function, batch_size=starting_batch_size)
+
+
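
A hypothetical usage sketch of the decorator above; `run_training` is an illustrative name, and with `auto_find_batch_size=False` the wrapped function simply runs once with `starting_batch_size`:

@find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
def run_training(batch_size):
    # With auto_find_batch_size=True (and `accelerate` installed), this body would be
    # retried with a smaller batch size after out-of-memory / CUDNN failures.
    print(f"training with batch_size={batch_size}")


run_training()  # prints: training with batch_size=64
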
+class FSDPOption(ExplicitEnum):
+ FULL_SHARD = "full_shard"
+ SHARD_GRAD_OP = "shard_grad_op"
+ NO_SHARD = "no_shard"
+ OFFLOAD = "offload"
+ AUTO_WRAP = "auto_wrap"
+
+
+class RemoveColumnsCollator:
+ """Wrap the data collator to remove unused columns before they are passed to the collator."""
+
+ def __init__(
+ self,
+ data_collator,
+ signature_columns,
+ logger=None,
+ model_name: Optional[str] = None,
+ description: Optional[str] = None,
+ ):
+ self.data_collator = data_collator
+ self.signature_columns = signature_columns
+ self.logger = logger
+ self.description = description
+ self.model_name = model_name
+ self.message_logged = False
+
+ def _remove_columns(self, feature: dict) -> dict:
+ if not isinstance(feature, dict):
+ return feature
+ if not self.message_logged and self.logger and self.model_name:
+ ignored_columns = list(set(feature.keys()) - set(self.signature_columns))
+ if len(ignored_columns) > 0:
+ dset_description = "" if self.description is None else f"in the {self.description} set"
+ self.logger.info(
+ f"The following columns {dset_description} don't have a corresponding argument in "
+ f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
+ f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
+ " you can safely ignore this message."
+ )
+ self.message_logged = True
+ return {k: v for k, v in feature.items() if k in self.signature_columns}
+
+ def __call__(self, features: List[dict]):
+ features = [self._remove_columns(feature) for feature in features]
+ return self.data_collator(features)
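
An illustrative use of `RemoveColumnsCollator` with a toy stand-in collator (not the real `default_data_collator`): columns outside `signature_columns` are dropped before the wrapped collator is called.

def toy_collator(features):
    # Naive list-of-dicts -> dict-of-lists collation, for illustration only.
    return {key: [f[key] for f in features] for key in features[0]}


collator = RemoveColumnsCollator(
    data_collator=toy_collator,
    signature_columns=["input_ids", "attention_mask", "labels"],
)
batch = collator([{"input_ids": [1, 2], "attention_mask": [1, 1], "idx": 0}])
# batch == {"input_ids": [[1, 2]], "attention_mask": [[1, 1]]} -- "idx" was removed
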
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index cc0a5ec83570..e662d6fca4fd 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -20,21 +20,32 @@
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
from .debug_utils import DebugOption
-from .trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
+from .trainer_utils import (
+ EvaluationStrategy,
+ FSDPOption,
+ HubStrategy,
+ IntervalStrategy,
+ SchedulerType,
+ ShardedDDPOption,
+)
from .utils import (
ExplicitEnum,
cached_property,
+ ccl_version,
get_full_repo_name,
+ is_accelerate_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_available,
- is_torch_bf16_available,
+ is_torch_bf16_cpu_available,
+ is_torch_bf16_gpu_available,
is_torch_tf32_available,
is_torch_tpu_available,
logging,
+ requires_backends,
torch_required,
)
@@ -43,7 +54,7 @@
import torch
import torch.distributed as dist
-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
@@ -69,6 +80,15 @@ def default_logdir() -> str:
return os.path.join("runs", current_time + "_" + socket.gethostname())
+def get_int_from_env(env_keys, default):
+ """Returns the first positive env value found in the `env_keys` list or the default."""
+ for e in env_keys:
+ val = int(os.environ.get(e, -1))
+ if val >= 0:
+ return val
+ return default
+
+
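
A quick sketch of how `get_int_from_env` resolves its value (the variable names below are placeholders):

import os

os.environ["HYPOTHETICAL_LOCAL_RANK"] = "3"
print(get_int_from_env(["UNSET_VARIABLE", "HYPOTHETICAL_LOCAL_RANK"], -1))  # 3
print(get_int_from_env(["ANOTHER_UNSET_VARIABLE"], 7))                      # 7 (default)
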
class OptimizerNames(ExplicitEnum):
"""
Stores the acceptable string identifiers for optimizers.
@@ -80,6 +100,8 @@ class OptimizerNames(ExplicitEnum):
ADAMW_APEX_FUSED = "adamw_apex_fused"
ADAFACTOR = "adafactor"
ADAMW_BNB = "adamw_bnb_8bit"
+ SGD = "sgd"
+ ADAGRAD = "adagrad"
@dataclass
@@ -227,9 +249,14 @@ class TrainingArguments:
Random seed to be used with data samplers. If not set, random generators for data sampling will use the
same seed as `seed`. This can be used to ensure reproducibility of data sampling, independent of the model
seed.
+ jit_mode_eval (`bool`, *optional*, defaults to `False`):
+ Whether or not to use PyTorch jit trace for inference.
+ use_ipex (`bool`, *optional*, defaults to `False`):
+ Use Intel extension for PyTorch when it is available. [IPEX
+ installation](https://github.com/intel/intel-extension-for-pytorch).
bf16 (`bool`, *optional*, defaults to `False`):
Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
- NVIDIA architecture. This is an experimental API and it may change.
+ NVIDIA architecture or using CPU (no_cuda). This is an experimental API and it may change.
fp16 (`bool`, *optional*, defaults to `False`):
Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
fp16_opt_level (`str`, *optional*, defaults to 'O1'):
@@ -238,9 +265,9 @@ class TrainingArguments:
fp16_backend (`str`, *optional*, defaults to `"auto"`):
This argument is deprecated. Use `half_precision_backend` instead.
half_precision_backend (`str`, *optional*, defaults to `"auto"`):
- The backend to use for mixed precision training. Must be one of `"auto"`, `"amp"` or `"apex"`. `"auto"`
- will use AMP or APEX depending on the PyTorch version detected, while the other choices will force the
- requested backend.
+ The backend to use for mixed precision training. Must be one of `"auto", "cuda_amp", "apex", "cpu_amp"`.
+ `"auto"` will use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices
+ will force the requested backend.
bf16_full_eval (`bool`, *optional*, defaults to `False`):
Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
metric values. This is an experimental API and it may change.
@@ -280,8 +307,7 @@ class TrainingArguments:
[`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
set to warn or lower (default), `False` otherwise.
remove_unused_columns (`bool`, *optional*, defaults to `True`):
- If using `datasets.Dataset` datasets, whether or not to automatically remove the columns unused by the
- model forward method.
+ Whether or not to automatically remove the columns unused by the model forward method.
(Note that this behavior is not implemented for [`TFTrainer`] yet.)
label_names (`List[str]`, *optional*):
@@ -294,8 +320,8 @@ class TrainingArguments:
- When set to `True`, the parameters `save_strategy` needs to be the same as `eval_strategy`, and in the case
- it is "steps", `save_steps` must be a round multiple of `eval_steps`.
+ When set to `True`, the parameter `save_strategy` needs to be the same as `evaluation_strategy`, and in
+ the case it is "steps", `save_steps` must be a round multiple of `eval_steps`.
@@ -331,6 +357,18 @@ class TrainingArguments:
If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
list for `False` and `["simple"]` for `True`.
+ fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `False`):
+ Use PyTorch Distributed Parallel Training (in distributed training only).
+
+ A list of options along the following:
+
+ - `"full_shard"`: Shard parameters, gradients and optimizer states.
+ - `"shard_grad_op"`: Shard optimizer states and gradients.
+ - `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and
+ `"shard_grad_op"`).
+ - `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.
+ fsdp_min_num_params (`int`, *optional*, defaults to `0`):
+ FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is passed).
deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
@@ -376,8 +414,8 @@ class TrainingArguments:
down the training and evaluation speed.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push the model to the Hub every time the model is saved. If this is activated,
- `output_dir` will begin a git directory synced with the the repo (determined by `hub_model_id`) and the
- content will be pushed each time a save is triggered (depending on your `save_strategy`). Calling
+ `output_dir` will begin a git directory synced with the repo (determined by `hub_model_id`) and the content
+ will be pushed each time a save is triggered (depending on your `save_strategy`). Calling
[`~Trainer.save_model`] will also trigger a push.
@@ -398,7 +436,7 @@ class TrainingArguments:
`"organization_name/model"`. Will default to `user_name/output_dir_name` with *output_dir_name* being the
name of `output_dir`.
- Will default to to the name of `output_dir`.
+ Will default to the name of `output_dir`.
hub_strategy (`str` or [`~trainer_utils.HubStrategy`], *optional*, defaults to `"every_save"`):
Defines the scope of what is pushed to the Hub and when. Possible values are:
@@ -424,6 +462,21 @@ class TrainingArguments:
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
that need inputs, predictions and references for scoring calculation in Metric class.
+ auto_find_batch_size (`bool`, *optional*, defaults to `False`):
+ Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
+ CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`).
+ full_determinism (`bool`, *optional*, defaults to `False`):
+ If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
+ distributed training.
+ torchdynamo (`str`, *optional*):
+ The token that is used to set the backend compiler for TorchDynamo. Possible choices are ["eager",
+ "nvfuser"]. This is an experimental API and subject to change.
+ ray_scope (`str`, *optional*, defaults to `"last"`):
+ The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray will
+ then use the last checkpoint of all trials, compare those, and select the best one. However, other options
+ are also available. See the [Ray documentation](
+ https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for
+ more options.
"""
output_dir: str = field(
@@ -442,7 +495,7 @@ class TrainingArguments:
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
- evaluation_strategy: IntervalStrategy = field(
+ evaluation_strategy: Union[IntervalStrategy, str] = field(
default="no",
metadata={"help": "The evaluation strategy to use."},
)
@@ -461,15 +514,19 @@ class TrainingArguments:
per_gpu_train_batch_size: Optional[int] = field(
default=None,
metadata={
- "help": "Deprecated, the use of `--per_device_train_batch_size` is preferred. "
- "Batch size per GPU/TPU core/CPU for training."
+ "help": (
+ "Deprecated, the use of `--per_device_train_batch_size` is preferred. "
+ "Batch size per GPU/TPU core/CPU for training."
+ )
},
)
per_gpu_eval_batch_size: Optional[int] = field(
default=None,
metadata={
- "help": "Deprecated, the use of `--per_device_eval_batch_size` is preferred. "
- "Batch size per GPU/TPU core/CPU for evaluation."
+ "help": (
+ "Deprecated, the use of `--per_device_eval_batch_size` is preferred. "
+ "Batch size per GPU/TPU core/CPU for evaluation."
+ )
},
)
@@ -485,7 +542,10 @@ class TrainingArguments:
eval_delay: Optional[float] = field(
default=0,
metadata={
- "help": "Number of epochs or steps to wait for before the first evaluation can be performed, depending on the evaluation_strategy."
+ "help": (
+ "Number of epochs or steps to wait for before the first evaluation can be performed, depending on the"
+ " evaluation_strategy."
+ )
},
)
@@ -501,7 +561,7 @@ class TrainingArguments:
default=-1,
metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
)
- lr_scheduler_type: SchedulerType = field(
+ lr_scheduler_type: Union[SchedulerType, str] = field(
default="linear",
metadata={"help": "The scheduler type to use."},
)
@@ -513,7 +573,11 @@ class TrainingArguments:
log_level: Optional[str] = field(
default="passive",
metadata={
- "help": "Logger log level to use on the main node. Possible choices are the log levels as strings: 'debug', 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and lets the application set the level. Defaults to 'passive'.",
+ "help": (
+ "Logger log level to use on the main node. Possible choices are the log levels as strings: 'debug',"
+ " 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and"
+ " lets the application set the level. Defaults to 'passive'."
+ ),
"choices": trainer_log_levels.keys(),
},
)
@@ -527,18 +591,21 @@ class TrainingArguments:
log_on_each_node: bool = field(
default=True,
metadata={
- "help": "When doing a multinode distributed training, whether to log once per node or just once on the main node."
+ "help": (
+ "When doing a multinode distributed training, whether to log once per node or just once on the main"
+ " node."
+ )
},
)
logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
- logging_strategy: IntervalStrategy = field(
+ logging_strategy: Union[IntervalStrategy, str] = field(
default="steps",
metadata={"help": "The logging strategy to use."},
)
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
- save_strategy: IntervalStrategy = field(
+ save_strategy: Union[IntervalStrategy, str] = field(
default="steps",
metadata={"help": "The checkpoint save strategy to use."},
)
@@ -555,16 +622,34 @@ class TrainingArguments:
save_on_each_node: bool = field(
default=False,
metadata={
- "help": "When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on the main one"
+ "help": (
+ "When doing multi-node distributed training, whether to save models and checkpoints on each node, or"
+ " only on the main one"
+ )
},
)
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
- data_seed: int = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
+ data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
+ jit_mode_eval: bool = field(
+ default=False, metadata={"help": "Whether or not to use PyTorch jit trace for inference"}
+ )
+ use_ipex: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Use Intel extension for PyTorch when it is available, installation:"
+ " 'https://github.com/intel/intel-extension-for-pytorch'"
+ )
+ },
+ )
bf16: bool = field(
default=False,
metadata={
- "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA architecture. This is an experimental API and it may change."
+ "help": (
+ "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA"
+ " architecture or using CPU (no_cuda). This is an experimental API and it may change."
+ )
},
)
fp16: bool = field(
@@ -582,26 +667,35 @@ class TrainingArguments:
)
half_precision_backend: str = field(
default="auto",
- metadata={"help": "The backend to be used for half precision.", "choices": ["auto", "amp", "apex"]},
+ metadata={
+ "help": "The backend to be used for half precision.",
+ "choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
+ },
)
bf16_full_eval: bool = field(
default=False,
metadata={
- "help": "Whether to use full bfloat16 evaluation instead of 32-bit. This is an experimental API and it may change."
+ "help": (
+ "Whether to use full bfloat16 evaluation instead of 32-bit. This is an experimental API and it may"
+ " change."
+ )
},
)
fp16_full_eval: bool = field(
default=False,
metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
)
- tf32: bool = field(
+ tf32: Optional[bool] = field(
default=None,
metadata={
- "help": "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change."
+ "help": (
+ "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental"
+ " API and it may change."
+ )
},
)
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
- xpu_backend: str = field(
+ xpu_backend: Optional[str] = field(
default=None,
metadata={"help": "The backend to be used for distributed training on Intel XPU.", "choices": ["mpi", "ccl"]},
)
@@ -611,26 +705,33 @@ class TrainingArguments:
tpu_metrics_debug: bool = field(
default=False,
metadata={
- "help": "Deprecated, the use of `--debug tpu_metrics_debug` is preferred. TPU: Whether to print debug metrics"
+ "help": (
+ "Deprecated, the use of `--debug tpu_metrics_debug` is preferred. TPU: Whether to print debug metrics"
+ )
},
)
debug: str = field(
default="",
metadata={
- "help": "Whether or not to enable debug mode. Current options: "
- "`underflow_overflow` (Detect underflow and overflow in activations and weights), "
- "`tpu_metrics_debug` (print debug metrics on TPU)."
+ "help": (
+ "Whether or not to enable debug mode. Current options: "
+ "`underflow_overflow` (Detect underflow and overflow in activations and weights), "
+ "`tpu_metrics_debug` (print debug metrics on TPU)."
+ )
},
)
dataloader_drop_last: bool = field(
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
)
- eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
+ eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
dataloader_num_workers: int = field(
default=0,
metadata={
- "help": "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the main process."
+ "help": (
+ "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded"
+ " in the main process."
+ )
},
)
@@ -666,28 +767,66 @@ class TrainingArguments:
ignore_data_skip: bool = field(
default=False,
metadata={
- "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
+ "help": (
+ "When resuming training, whether or not to skip the first epochs and batches to get to the same"
+ " training data."
+ )
},
)
sharded_ddp: str = field(
default="",
metadata={
- "help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
- "should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
- "like this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or "
- "with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`.",
+ "help": (
+ "Whether or not to use sharded DDP training (in distributed training only). The base option should be"
+ " `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` like"
+ " this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or `zero_dp_3`"
+ " with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`."
+ ),
+ },
+ )
+ fsdp: str = field(
+ default="",
+ metadata={
+ "help": (
+ "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training"
+ " only). The base option should be `full_shard`, `shard_grad_op` or `no_shard` and you can add"
+ " CPU-offload to `full_shard` or `shard_grad_op` like this: full_shard offload` or `shard_grad_op"
+ " offload`. You can add auto-wrap to `full_shard` or `shard_grad_op` with the same syntax: full_shard"
+ " auto_wrap` or `shard_grad_op auto_wrap`."
+ ),
+ },
+ )
+ fsdp_min_num_params: int = field(
+ default=0,
+ metadata={
+ "help": (
+ "FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is"
+ " passed)."
+ )
+ },
+ )
+ fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... "
+ "(useful only when `fsdp` flag is passed)."
+ )
},
)
deepspeed: Optional[str] = field(
default=None,
metadata={
- "help": "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json) or an already loaded json file as a dict"
+ "help": (
+ "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json) or an already"
+ " loaded json file as a dict"
+ )
},
)
label_smoothing_factor: float = field(
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
)
- optim: OptimizerNames = field(
+ optim: Union[OptimizerNames, str] = field(
default="adamw_hf",
metadata={"help": "The optimizer to use."},
)
@@ -706,15 +845,19 @@ class TrainingArguments:
ddp_find_unused_parameters: Optional[bool] = field(
default=None,
metadata={
- "help": "When using distributed training, the value of the flag `find_unused_parameters` passed to "
- "`DistributedDataParallel`."
+ "help": (
+ "When using distributed training, the value of the flag `find_unused_parameters` passed to "
+ "`DistributedDataParallel`."
+ )
},
)
ddp_bucket_cap_mb: Optional[int] = field(
default=None,
metadata={
- "help": "When using distributed training, the value of the flag `bucket_cap_mb` passed to "
- "`DistributedDataParallel`."
+ "help": (
+ "When using distributed training, the value of the flag `bucket_cap_mb` passed to "
+ "`DistributedDataParallel`."
+ )
},
)
dataloader_pin_memory: bool = field(
@@ -733,14 +876,14 @@ class TrainingArguments:
default=None,
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
)
- hub_model_id: str = field(
+ hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
- hub_strategy: HubStrategy = field(
+ hub_strategy: Union[HubStrategy, str] = field(
default="every_save",
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
)
- hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+ hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
hub_private_repo: bool = field(default=False, metadata={"help": "Whether the model repository is private or not."})
gradient_checkpointing: bool = field(
default=False,
@@ -754,21 +897,72 @@ class TrainingArguments:
# Deprecated arguments
fp16_backend: str = field(
default="auto",
- metadata={"help": "Deprecated. Use half_precision_backend instead", "choices": ["auto", "amp", "apex"]},
+ metadata={
+ "help": "Deprecated. Use half_precision_backend instead",
+ "choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
+ },
)
- push_to_hub_model_id: str = field(
+ push_to_hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
)
- push_to_hub_organization: str = field(
+ push_to_hub_organization: Optional[str] = field(
default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."}
)
- push_to_hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+ push_to_hub_token: Optional[str] = field(
+ default=None, metadata={"help": "The token to use to push to the Model Hub."}
+ )
_n_gpu: int = field(init=False, repr=False, default=-1)
mp_parameters: str = field(
default="",
metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"},
)
+ auto_find_batch_size: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to automatically decrease the batch size in half and rerun the training loop again each time"
+ " a CUDA Out-of-Memory was reached"
+ )
+ },
+ )
+ full_determinism: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to call enable_full_determinism instead of set_seed for reproducibility in distributed"
+ " training"
+ )
+ },
+ )
+ torchdynamo: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Sets up the backend compiler for TorchDynamo. TorchDynamo is a Python level JIT compiler designed to"
+ " make unmodified PyTorch programs faster. TorchDynamo dynamically modifies the Python bytecode right"
+ " before its executed. It rewrites Python bytecode to extract sequences of PyTorch operations"
+ " and lifts them up into Fx graph. We can then pass these Fx graphs to other backend compilers. There"
+ " are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging."
+ " nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
+ ),
+ "choices": ["eager", "nvfuser", "fx2trt", "fx2trt-fp16"],
+ },
+ )
+ ray_scope: Optional[str] = field(
+ default="last",
+ metadata={
+ "help": (
+ 'The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray'
+ " will then use the last checkpoint of all trials, compare those, and select the best one. However,"
+ " other options are also available. See the Ray documentation"
+ " (https://docs.ray.io/en/latest/tune/api_docs/analysis.html"
+ "#ray.tune.ExperimentAnalysis.get_best_trial)"
+ " for more options."
+ )
+ },
+ )
+
def __post_init__(self):
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
# This needs to happen before any call to self.device or self.n_gpu.
@@ -795,7 +989,8 @@ def __post_init__(self):
if isinstance(self.evaluation_strategy, EvaluationStrategy):
warnings.warn(
- "using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5 of š¤ Transformers. Use `IntervalStrategy` instead",
+ "using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5"
+ " of š¤ Transformers. Use `IntervalStrategy` instead",
FutureWarning,
)
# Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it.
@@ -817,7 +1012,8 @@ def __post_init__(self):
self.eval_steps = self.logging_steps
else:
raise ValueError(
- f"evaluation strategy {self.evaluation_strategy} requires either non-zero --eval_steps or --logging_steps"
+ f"evaluation strategy {self.evaluation_strategy} requires either non-zero --eval_steps or"
+ " --logging_steps"
)
# logging_steps must be non-zero for logging_strategy that is other than 'no'
@@ -846,20 +1042,34 @@ def __post_init__(self):
if self.fp16_backend and self.fp16_backend != "auto":
warnings.warn(
- "`fp16_backend` is deprecated and will be removed in version 5 of š¤ Transformers. Use `half_precision_backend` instead",
+ "`fp16_backend` is deprecated and will be removed in version 5 of š¤ Transformers. Use"
+ " `half_precision_backend` instead",
FutureWarning,
)
self.half_precision_backend = self.fp16_backend
- if (self.bf16 or self.bf16_full_eval) and not is_torch_bf16_available():
- raise ValueError("Your setup doesn't support bf16. You need Ampere GPU, torch>=1.10, cuda>=11.0")
+ if self.bf16 or self.bf16_full_eval:
+
+ if self.no_cuda and not is_torch_bf16_cpu_available():
+ # cpu
+ raise ValueError("Your setup doesn't support bf16/cpu. You need torch>=1.10")
+ elif not self.no_cuda and not is_torch_bf16_gpu_available():
+ # gpu
+ raise ValueError(
+ "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
+ )
if self.fp16 and self.bf16:
raise ValueError("At most one of fp16 and bf16 can be True, but not both")
+
+ if self.fp16_full_eval and self.bf16_full_eval:
+ raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both")
+
if self.bf16:
if self.half_precision_backend == "apex":
raise ValueError(
- " `--half_precision_backend apex`: bf16 is not supported by apex. Use `--half_precision_backend amp` instead"
+ " `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use"
+ " `--half_precision_backend cuda_amp` instead"
)
if not (self.sharded_ddp == "" or not self.sharded_ddp):
raise ValueError("sharded_ddp is not supported with bf16")
@@ -867,7 +1077,8 @@ def __post_init__(self):
self.optim = OptimizerNames(self.optim)
if self.adafactor:
warnings.warn(
- "`--adafactor` is deprecated and will be removed in version 5 of š¤ Transformers. Use `--optim adafactor` instead",
+ "`--adafactor` is deprecated and will be removed in version 5 of š¤ Transformers. Use `--optim"
+ " adafactor` instead",
FutureWarning,
)
self.optim = OptimizerNames.ADAFACTOR
@@ -876,10 +1087,23 @@ def __post_init__(self):
is_torch_available()
and (self.device.type != "cuda")
and not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ)
- and (self.fp16 or self.fp16_full_eval or self.bf16 or self.bf16_full_eval)
+ and (self.fp16 or self.fp16_full_eval)
+ ):
+ raise ValueError(
+ "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
+ " (`--fp16_full_eval`) can only be used on CUDA devices."
+ )
+
+ if (
+ is_torch_available()
+ and (self.device.type != "cuda")
+ and not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ)
+ and (self.device.type != "cpu")
+ and (self.bf16 or self.bf16_full_eval)
):
raise ValueError(
- "Mixed precision training with AMP or APEX (`--fp16` or `--bf16`) and half precision evaluation (`--fp16_full_eval` or `--bf16_full_eval`) can only be used on CUDA devices."
+ "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
+ " (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
)
if is_torch_available() and self.tf32 is not None:
@@ -914,7 +1138,8 @@ def __post_init__(self):
raise ValueError("warmup_ratio must lie in range [0,1]")
elif self.warmup_ratio > 0 and self.warmup_steps > 0:
logger.info(
- "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
+ "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio"
+ " during training"
)
if isinstance(self.sharded_ddp, bool):
@@ -931,9 +1156,33 @@ def __post_init__(self):
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
+ if isinstance(self.fsdp, bool):
+ self.fsdp = "full_shard" if self.fsdp else ""
+ if isinstance(self.fsdp, str):
+ self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
+ if self.fsdp == [FSDPOption.OFFLOAD]:
+ raise ValueError(
+ "`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or "
+ '`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.'
+ )
+ elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
+ raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
+
+ if len(self.fsdp) == 0 and self.fsdp_min_num_params > 0:
+ warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
+
+ if len(self.fsdp) == 0 and self.fsdp_transformer_layer_cls_to_wrap is not None:
+ warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
+
+ if len(self.fsdp) > 0 and self.fsdp_min_num_params > 0 and self.fsdp_transformer_layer_cls_to_wrap is not None:
+ raise ValueError(
+ "`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
+ )
+
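
The parsing above can be sketched in isolation (assuming this branch of transformers is installed, since `FSDPOption` is added by this diff): a space-separated `--fsdp` value becomes a list of `FSDPOption` members.

from transformers.trainer_utils import FSDPOption

options = [FSDPOption(s) for s in "full_shard offload auto_wrap".split()]
# -> [FSDPOption.FULL_SHARD, FSDPOption.OFFLOAD, FSDPOption.AUTO_WRAP]
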
if self.tpu_metrics_debug:
warnings.warn(
- "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of š¤ Transformers. Use `--debug tpu_metrics_debug` instead",
+ "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of š¤ Transformers. Use"
+ " `--debug tpu_metrics_debug` instead",
FutureWarning,
)
self.debug += " tpu_metrics_debug"
@@ -944,6 +1193,8 @@ def __post_init__(self):
if self.deepspeed:
# - must be run very last in arg parsing, since it will use a lot of these settings.
# - must be run before the model is created.
+ if not is_accelerate_available():
+ raise ValueError("--deepspeed requires Accelerate to be installed: `pip install accelerate`.")
from transformers.deepspeed import HfTrainerDeepSpeedConfig
# will be used later by the Trainer
@@ -1041,6 +1292,10 @@ def _setup_devices(self) -> "torch.device":
if self.no_cuda:
device = torch.device("cpu")
self._n_gpu = 0
+ self.local_rank = get_int_from_env(
+ ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"],
+ self.local_rank,
+ )
if self.local_rank != -1 and not torch.distributed.is_initialized():
# Initializes distributed backend for cpu
if self.xpu_backend not in ("mpi", "ccl"):
@@ -1048,7 +1303,36 @@ def _setup_devices(self) -> "torch.device":
"CPU distributed training backend is not properly set. "
"Please set '--xpu_backend' to either 'mpi' or 'ccl'."
)
- torch.distributed.init_process_group(backend=self.xpu_backend)
+ if self.xpu_backend == "ccl":
+ requires_backends(self, "oneccl_bind_pt")
+ if ccl_version >= "1.12":
+ import oneccl_bindings_for_pytorch # noqa: F401
+ else:
+ import torch_ccl # noqa: F401
+ if int(os.environ.get("CCL_WORKER_COUNT", 0)) < 1:
+ raise ValueError(
+ "CPU distributed training backend is ccl. but CCL_WORKER_COUNT is not correctly set. "
+ "Please use like 'export CCL_WORKER_COUNT = 1' to set."
+ )
+
+ # Try to get launch configuration from environment variables set by MPI launcher - works for Intel MPI, OpenMPI and MVAPICH
+ rank = get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0)
+ size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1)
+ local_size = get_int_from_env(
+ ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1
+ )
+ os.environ["RANK"] = str(rank)
+ os.environ["WORLD_SIZE"] = str(size)
+ os.environ["LOCAL_RANK"] = str(self.local_rank)
+ if not os.environ.get("MASTER_PORT", None):
+ os.environ["MASTER_PORT"] = "29500"
+ if not os.environ.get("MASTER_ADDR", None):
+ if local_size != size or self.xpu_backend != "mpi":
+ raise ValueError(
+ "Looks like distributed multinode run but MASTER_ADDR env not set, "
+ "please try exporting rank 0's hostname as MASTER_ADDR"
+ )
+ torch.distributed.init_process_group(backend=self.xpu_backend, rank=rank, world_size=size)
elif is_torch_tpu_available():
device = xm.xla_device()
self._n_gpu = 0
@@ -1057,6 +1341,8 @@ def _setup_devices(self) -> "torch.device":
device = torch.device("cuda", local_rank)
self._n_gpu = 1
elif is_sagemaker_dp_enabled():
+ import smdistributed.dataparallel.torch.torch_smddp # noqa: F401
+
dist.init_process_group(backend="smddp")
self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
device = torch.device("cuda", self.local_rank)
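
For the CPU/ccl branch above, a sketch of the environment a launcher would normally provide (all values are placeholders for a single-process run; `mpirun` or a similar launcher usually exports these for you):

import os

os.environ.setdefault("CCL_WORKER_COUNT", "1")    # required when --xpu_backend ccl is used
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("LOCAL_RANK", "0")
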
diff --git a/src/transformers/training_args_seq2seq.py b/src/transformers/training_args_seq2seq.py
index ef3ccdf26017..026dce81bcfd 100644
--- a/src/transformers/training_args_seq2seq.py
+++ b/src/transformers/training_args_seq2seq.py
@@ -51,14 +51,18 @@ class Seq2SeqTrainingArguments(TrainingArguments):
generation_max_length: Optional[int] = field(
default=None,
metadata={
- "help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
- "to the `max_length` value of the model configuration."
+ "help": (
+ "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+ "to the `max_length` value of the model configuration."
+ )
},
)
generation_num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
- "to the `num_beams` value of the model configuration."
+ "help": (
+ "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+ "to the `num_beams` value of the model configuration."
+ )
},
)
diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py
index 4f3c41e2cab2..060b78e92205 100644
--- a/src/transformers/training_args_tf.py
+++ b/src/transformers/training_args_tf.py
@@ -14,7 +14,7 @@
import warnings
from dataclasses import dataclass, field
-from typing import Tuple
+from typing import Optional, Tuple
from .training_args import TrainingArguments
from .utils import cached_property, is_tf_available, logging, tf_required
@@ -161,17 +161,17 @@ class TFTrainingArguments(TrainingArguments):
Whether to activate the XLA compilation or not.
"""
- tpu_name: str = field(
+ tpu_name: Optional[str] = field(
default=None,
metadata={"help": "Name of TPU"},
)
- tpu_zone: str = field(
+ tpu_zone: Optional[str] = field(
default=None,
metadata={"help": "Zone of TPU"},
)
- gcp_project: str = field(
+ gcp_project: Optional[str] = field(
default=None,
metadata={"help": "Name of Cloud TPU-enabled project"},
)
@@ -195,8 +195,7 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
# Set to float16 at first
if self.fp16:
- policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
- tf.keras.mixed_precision.experimental.set_policy(policy)
+ tf.keras.mixed_precision.set_global_policy("mixed_float16")
if self.no_cuda:
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
@@ -217,8 +216,7 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
if tpu:
# Set to bfloat16 in case of TPU
if self.fp16:
- policy = tf.keras.mixed_precision.experimental.Policy("mixed_bfloat16")
- tf.keras.mixed_precision.experimental.set_policy(policy)
+ tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
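
A standalone sketch of the Keras mixed-precision call this diff switches to (assumes TensorFlow >= 2.4, where the non-experimental API is available):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")
print(tf.keras.mixed_precision.global_policy().name)  # "mixed_float16"
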
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 6101a924f969..27276aa4946d 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -22,6 +22,7 @@
from packaging import version
from .. import __version__
+from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from .doc import (
add_code_sample_docstrings,
add_end_docstrings,
@@ -38,9 +39,11 @@
TensorType,
cached_property,
find_labels,
+ flatten_dict,
is_tensor,
to_numpy,
to_py_obj,
+ working_or_temp_dir,
)
from .hub import (
CLOUDFRONT_DISTRIB_PREFIX,
@@ -57,23 +60,17 @@
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
- cached_path,
+ cached_file,
default_cache_path,
define_sagemaker_information,
- filename_to_url,
get_cached_models,
get_file_from_repo,
- get_from_cache,
get_full_repo_name,
- get_list_of_files,
has_file,
- hf_bucket_url,
- http_get,
http_user_agent,
- is_local_clone,
is_offline_mode,
- is_remote_url,
- url_to_filename,
+ move_cache,
+ send_example_telemetry,
)
from .import_utils import (
ENV_VARS_TRUE_AND_AUTO_VALUES,
@@ -83,7 +80,10 @@
USE_TF,
USE_TORCH,
DummyObject,
+ OptionalDependencyNotAvailable,
_LazyModule,
+ ccl_version,
+ is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
is_coloredlogs_available,
@@ -93,6 +93,7 @@
is_flax_available,
is_ftfy_available,
is_in_notebook,
+ is_ipex_available,
is_librosa_available,
is_onnx_available,
is_pandas_available,
@@ -104,6 +105,7 @@
is_pytesseract_available,
is_pytorch_quantization_available,
is_rjieba_available,
+ is_sacremoses_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_scatter_available,
@@ -114,19 +116,24 @@
is_spacy_available,
is_speech_available,
is_tensorflow_probability_available,
+ is_tensorflow_text_available,
is_tf2onnx_available,
is_tf_available,
is_timm_available,
is_tokenizers_available,
is_torch_available,
is_torch_bf16_available,
+ is_torch_bf16_cpu_available,
+ is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_fx_available,
is_torch_fx_proxy,
is_torch_onnx_dict_inputs_support_available,
+ is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torchaudio_available,
+ is_torchdynamo_available,
is_training_run_on_sagemaker,
is_vision_available,
requires_backends,
@@ -140,8 +147,10 @@
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
TF2_WEIGHTS_NAME = "tf_model.h5"
+TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
+FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
MODEL_CARD_NAME = "modelcard.json"
@@ -168,8 +177,6 @@ def check_min_version(min_version):
error_message += f" but the version found is {__version__}.\n"
raise ImportError(
error_message
- + (
- "Check out https://huggingface.co/transformers/examples.html for the examples corresponding to other "
- "versions of HuggingFace Transformers."
- )
+ + "Check out https://huggingface.co/transformers/examples.html for the examples corresponding to other "
+ "versions of HuggingFace Transformers."
)
diff --git a/src/transformers/utils/constants.py b/src/transformers/utils/constants.py
new file mode 100644
index 000000000000..af2e48ab0a8b
--- /dev/null
+++ b/src/transformers/utils/constants.py
@@ -0,0 +1,4 @@
+IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
+IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
+IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
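
A minimal sketch of how these constants are typically consumed, normalizing a float image in [0, 1] channel-wise (the random image is a placeholder):

import numpy as np

image = np.random.rand(3, 224, 224).astype("float32")   # toy CHW image
mean = np.array(IMAGENET_DEFAULT_MEAN).reshape(3, 1, 1)
std = np.array(IMAGENET_DEFAULT_STD).reshape(3, 1, 1)
normalized = (image - mean) / std
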
diff --git a/src/transformers/utils/doc.py b/src/transformers/utils/doc.py
index 8f0caf825bba..6761dec9c969 100644
--- a/src/transformers/utils/doc.py
+++ b/src/transformers/utils/doc.py
@@ -428,8 +428,7 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None):
```
```python
- >>> with processor.as_target_processor():
- ... inputs["labels"] = processor(dataset[0]["text"], return_tensors="pt").input_ids
+ >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids
>>> # compute loss
>>> loss = model(**inputs).loss
@@ -849,8 +848,7 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None):
```
```python
- >>> with processor.as_target_processor():
- ... inputs["labels"] = processor(dataset[0]["text"], return_tensors="tf").input_ids
+ >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="tf").input_ids
>>> # compute loss
>>> loss = model(**inputs).loss
diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py
index a6c6e7926da1..953808dab8ad 100644
--- a/src/transformers/utils/dummy_flax_objects.py
+++ b/src/transformers/utils/dummy_flax_objects.py
@@ -725,6 +725,27 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
+class FlaxLongT5ForConditionalGeneration(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
+class FlaxLongT5Model(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
+class FlaxLongT5PreTrainedModel(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
class FlaxMarianModel(metaclass=DummyObject):
_backends = ["flax"]
@@ -781,6 +802,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
+class FlaxMT5EncoderModel(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
class FlaxMT5ForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]
@@ -795,6 +823,27 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
+class FlaxOPTForCausalLM(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
+class FlaxOPTModel(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
+class FlaxOPTPreTrainedModel(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
class FlaxPegasusForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]
@@ -928,6 +977,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
+class FlaxT5EncoderModel(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
class FlaxT5ForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 112759671bbf..d636be655af2 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -234,6 +234,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class TypicalLogitsWarper(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class MaxLengthCriteria(metaclass=DummyObject):
_backends = ["torch"]
@@ -399,9 +406,15 @@ def load_tf_weights_in_albert(*args, **kwargs):
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
+MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = None
+
+
MODEL_FOR_VISION_2_SEQ_MAPPING = None
+MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None
+
+
MODEL_MAPPING = None
@@ -562,6 +575,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class AutoModelForVideoClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class AutoModelForVision2Seq(metaclass=DummyObject):
_backends = ["torch"]
@@ -569,6 +589,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class AutoModelForVisualQuestionAnswering(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class AutoModelWithLMHead(metaclass=DummyObject):
_backends = ["torch"]
@@ -959,6 +986,44 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class BloomForCausalLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class BloomForSequenceClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class BloomForTokenClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class BloomModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class BloomPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1098,6 +1163,30 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class CodeGenForCausalLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class CodeGenModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class CodeGenPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1216,6 +1305,30 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+CVT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class CvtForImageClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class CvtModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class CvtPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1406,6 +1519,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class DebertaV2ForMultipleChoice(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class DebertaV2ForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
@@ -1780,6 +1900,58 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class FlavaForPreTraining(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class FlavaImageCodebook(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class FlavaImageModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class FlavaModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class FlavaMultimodalModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class FlavaPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class FlavaTextModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
FNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2052,6 +2224,37 @@ def load_tf_weights_in_gpt_neo(*args, **kwargs):
requires_backends(load_tf_weights_in_gpt_neo, ["torch"])
+GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class GPTNeoXForCausalLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class GPTNeoXLayer(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class GPTNeoXModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class GPTNeoXPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2090,6 +2293,37 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class GroupViTModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class GroupViTPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class GroupViTTextModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class GroupViTVisionModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2284,215 +2518,343 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-LED_PRETRAINED_MODEL_ARCHIVE_LIST = None
+LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST = None
-class LEDForConditionalGeneration(metaclass=DummyObject):
+class LayoutLMv3ForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LEDForQuestionAnswering(metaclass=DummyObject):
+class LayoutLMv3ForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LEDForSequenceClassification(metaclass=DummyObject):
+class LayoutLMv3ForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LEDModel(metaclass=DummyObject):
+class LayoutLMv3Model(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LEDPreTrainedModel(metaclass=DummyObject):
+class LayoutLMv3PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
+LED_PRETRAINED_MODEL_ARCHIVE_LIST = None
-class LongformerForMaskedLM(metaclass=DummyObject):
+class LEDForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LongformerForMultipleChoice(metaclass=DummyObject):
+class LEDForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LongformerForQuestionAnswering(metaclass=DummyObject):
+class LEDForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LongformerForSequenceClassification(metaclass=DummyObject):
+class LEDModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LongformerForTokenClassification(metaclass=DummyObject):
+class LEDPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LongformerModel(metaclass=DummyObject):
+LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class LevitForImageClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LongformerPreTrainedModel(metaclass=DummyObject):
+class LevitForImageClassificationWithTeacher(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LongformerSelfAttention(metaclass=DummyObject):
+class LevitModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-LUKE_PRETRAINED_MODEL_ARCHIVE_LIST = None
-
-
-class LukeForEntityClassification(metaclass=DummyObject):
+class LevitPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LukeForEntityPairClassification(metaclass=DummyObject):
+LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class LongformerForMaskedLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LukeForEntitySpanClassification(metaclass=DummyObject):
+class LongformerForMultipleChoice(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LukeForMaskedLM(metaclass=DummyObject):
+class LongformerForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LukeModel(metaclass=DummyObject):
+class LongformerForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LukePreTrainedModel(metaclass=DummyObject):
+class LongformerForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LxmertEncoder(metaclass=DummyObject):
+class LongformerModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LxmertForPreTraining(metaclass=DummyObject):
+class LongformerPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LxmertForQuestionAnswering(metaclass=DummyObject):
+class LongformerSelfAttention(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LxmertModel(metaclass=DummyObject):
+LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class LongT5EncoderModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LxmertPreTrainedModel(metaclass=DummyObject):
+class LongT5ForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LxmertVisualFeatureEncoder(metaclass=DummyObject):
+class LongT5Model(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class LxmertXLayer(metaclass=DummyObject):
+class LongT5PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST = None
+LUKE_PRETRAINED_MODEL_ARCHIVE_LIST = None
-class M2M100ForConditionalGeneration(metaclass=DummyObject):
+class LukeForEntityClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class M2M100Model(metaclass=DummyObject):
+class LukeForEntityPairClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class M2M100PreTrainedModel(metaclass=DummyObject):
+class LukeForEntitySpanClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LukeForMaskedLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LukeForMultipleChoice(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LukeForQuestionAnswering(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LukeForSequenceClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LukeForTokenClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LukeModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LukePreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LxmertEncoder(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LxmertForPreTraining(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LxmertForQuestionAnswering(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LxmertModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LxmertPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LxmertVisualFeatureEncoder(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LxmertXLayer(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class M2M100ForConditionalGeneration(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class M2M100Model(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class M2M100PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -2586,6 +2948,30 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+MCTCT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class MCTCTForCTC(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MCTCTModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MCTCTPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2757,6 +3143,37 @@ def load_tf_weights_in_mobilebert(*args, **kwargs):
requires_backends(load_tf_weights_in_mobilebert, ["torch"])
+MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class MobileViTForImageClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MobileViTForSemanticSegmentation(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MobileViTModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MobileViTPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
MPNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2837,6 +3254,117 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+MVP_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class MvpForCausalLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MvpForConditionalGeneration(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MvpForQuestionAnswering(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MvpForSequenceClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MvpModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MvpPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+NEZHA_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class NezhaForMaskedLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class NezhaForMultipleChoice(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class NezhaForNextSentencePrediction(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class NezhaForPreTraining(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class NezhaForQuestionAnswering(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class NezhaForSequenceClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class NezhaForTokenClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class NezhaModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class NezhaPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
NYSTROMFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2938,6 +3466,75 @@ def load_tf_weights_in_openai_gpt(*args, **kwargs):
requires_backends(load_tf_weights_in_openai_gpt, ["torch"])
+OPT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class OPTForCausalLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class OPTForSequenceClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class OPTModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class OPTPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class OwlViTForObjectDetection(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class OwlViTModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class OwlViTPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class OwlViTTextModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class OwlViTVisionModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class PegasusForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
@@ -3785,6 +4382,13 @@ def __init__(self, *args, **kwargs):
SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST = None
+class SplinterForPreTraining(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class SplinterForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
@@ -3903,6 +4507,37 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class Swinv2ForImageClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Swinv2ForMaskedImageModeling(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Swinv2Model(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Swinv2PreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
T5_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -3938,6 +4573,23 @@ def load_tf_weights_in_t5(*args, **kwargs):
requires_backends(load_tf_weights_in_t5, ["torch"])
+TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class TrajectoryTransformerModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class TrajectoryTransformerPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -4111,6 +4763,37 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+VIDEOMAE_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class VideoMAEForPreTraining(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class VideoMAEForVideoClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class VideoMAEModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class VideoMAEPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
VILT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -4142,6 +4825,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class ViltForTokenClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class ViltLayer(metaclass=DummyObject):
_backends = ["torch"]
@@ -4357,6 +5047,58 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class Wav2Vec2ConformerForAudioFrameClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2ConformerForCTC(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2ConformerForPreTraining(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2ConformerForSequenceClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2ConformerForXVector(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2ConformerModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2ConformerPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
diff --git a/src/transformers/utils/dummy_sentencepiece_objects.py b/src/transformers/utils/dummy_sentencepiece_objects.py
index 00989dc0d12a..69f0bdcb7b1a 100644
--- a/src/transformers/utils/dummy_sentencepiece_objects.py
+++ b/src/transformers/utils/dummy_sentencepiece_objects.py
@@ -115,6 +115,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["sentencepiece"])
+class NllbTokenizer(metaclass=DummyObject):
+ _backends = ["sentencepiece"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["sentencepiece"])
+
+
class PegasusTokenizer(metaclass=DummyObject):
_backends = ["sentencepiece"]
diff --git a/src/transformers/utils/dummy_speech_objects.py b/src/transformers/utils/dummy_speech_objects.py
index 721fe80a7925..ae5589292a4c 100644
--- a/src/transformers/utils/dummy_speech_objects.py
+++ b/src/transformers/utils/dummy_speech_objects.py
@@ -3,6 +3,13 @@
from ..utils import DummyObject, requires_backends
+class MCTCTFeatureExtractor(metaclass=DummyObject):
+ _backends = ["speech"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["speech"])
+
+
class Speech2TextFeatureExtractor(metaclass=DummyObject):
_backends = ["speech"]
diff --git a/src/transformers/utils/dummy_tensorflow_text_objects.py b/src/transformers/utils/dummy_tensorflow_text_objects.py
new file mode 100644
index 000000000000..691774bb6bbf
--- /dev/null
+++ b/src/transformers/utils/dummy_tensorflow_text_objects.py
@@ -0,0 +1,10 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+from ..utils import DummyObject, requires_backends
+
+
+class TFBertTokenizer(metaclass=DummyObject):
+ _backends = ["tensorflow_text"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tensorflow_text"])
diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py
index e4a47290b5f5..6df601ca646a 100644
--- a/src/transformers/utils/dummy_tf_objects.py
+++ b/src/transformers/utils/dummy_tf_objects.py
@@ -261,6 +261,9 @@ def __init__(self, *args, **kwargs):
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
+TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = None
+
+
TF_MODEL_FOR_MASKED_LM_MAPPING = None
@@ -276,6 +279,9 @@ def __init__(self, *args, **kwargs):
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None
+TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = None
+
+
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None
@@ -335,6 +341,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
+class TFAutoModelForNextSentencePrediction(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
class TFAutoModelForPreTraining(metaclass=DummyObject):
_backends = ["tf"]
@@ -749,6 +762,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
+class TFData2VecVisionForSemanticSegmentation(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
class TFData2VecVisionModel(metaclass=DummyObject):
_backends = ["tf"]
@@ -853,6 +873,44 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
+TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class TFDeiTForImageClassification(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFDeiTForImageClassificationWithTeacher(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFDeiTForMaskedImageModeling(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFDeiTModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFDeiTPreTrainedModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1609,6 +1667,27 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
+class TFOPTForCausalLM(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFOPTModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFOPTPreTrainedModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
class TFPegasusForConditionalGeneration(metaclass=DummyObject):
_backends = ["tf"]
@@ -1658,6 +1737,30 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
+TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class TFRegNetForImageClassification(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFRegNetModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFRegNetPreTrainedModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1724,6 +1827,30 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
+TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class TFResNetForImageClassification(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFResNetModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFResNetPreTrainedModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1856,6 +1983,44 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
+TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class TFSegformerDecodeHead(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFSegformerForImageClassification(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFSegformerForSemanticSegmentation(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFSegformerModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFSegformerPreTrainedModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1880,6 +2045,37 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
+TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class TFSwinForImageClassification(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFSwinForMaskedImageModeling(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFSwinModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFSwinPreTrainedModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = None
diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py
index 12cec6a4a260..755be5c48ae5 100644
--- a/src/transformers/utils/dummy_tokenizers_objects.py
+++ b/src/transformers/utils/dummy_tokenizers_objects.py
@@ -52,6 +52,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
+class BloomTokenizerFast(metaclass=DummyObject):
+ _backends = ["tokenizers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tokenizers"])
+
+
class CamembertTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
@@ -66,6 +73,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
+class CodeGenTokenizerFast(metaclass=DummyObject):
+ _backends = ["tokenizers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tokenizers"])
+
+
class ConvBertTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
@@ -150,6 +164,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
+class GPTNeoXTokenizerFast(metaclass=DummyObject):
+ _backends = ["tokenizers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tokenizers"])
+
+
class HerbertTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
@@ -171,6 +192,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
+class LayoutLMv3TokenizerFast(metaclass=DummyObject):
+ _backends = ["tokenizers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tokenizers"])
+
+
class LayoutXLMTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
@@ -234,6 +262,20 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
+class MvpTokenizerFast(metaclass=DummyObject):
+ _backends = ["tokenizers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tokenizers"])
+
+
+class NllbTokenizerFast(metaclass=DummyObject):
+ _backends = ["tokenizers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tokenizers"])
+
+
class OpenAIGPTTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py
index 8ba819156143..30228e022222 100644
--- a/src/transformers/utils/dummy_vision_objects.py
+++ b/src/transformers/utils/dummy_vision_objects.py
@@ -59,6 +59,20 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
+class FlavaFeatureExtractor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
+class FlavaProcessor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
class GLPNFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
@@ -80,14 +94,14 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
-class LayoutLMv2Processor(metaclass=DummyObject):
+class LayoutLMv3FeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
-class LayoutXLMProcessor(metaclass=DummyObject):
+class LevitFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
@@ -101,6 +115,20 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
+class MobileViTFeatureExtractor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
+class OwlViTFeatureExtractor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
class PerceiverFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
@@ -122,6 +150,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
+class VideoMAEFeatureExtractor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
class ViltFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index d112a7151680..2198928eadb3 100644
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -14,39 +14,40 @@
# limitations under the License.
import builtins
+import collections
import functools
import inspect
import math
+import operator
import random
import warnings
-from copy import deepcopy
-from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch
from packaging import version
from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer
+from torch.fx.proxy import ParameterProxy
-from .. import (
- CONFIG_MAPPING,
- MODEL_FOR_CAUSAL_LM_MAPPING,
- MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
- MODEL_FOR_MASKED_LM_MAPPING,
- MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
- MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
- MODEL_FOR_PRETRAINING_MAPPING,
- MODEL_FOR_QUESTION_ANSWERING_MAPPING,
- MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
- MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
- MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
- MODEL_MAPPING,
- GPT2DoubleHeadsModel,
- PretrainedConfig,
- PreTrainedModel,
- XLNetForQuestionAnswering,
- logging,
-)
+from .. import PretrainedConfig, PreTrainedModel, logging
from ..models.auto import get_values
+from ..models.auto.modeling_auto import (
+ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+ MODEL_FOR_CTC_MAPPING_NAMES,
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
+ MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
+ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
+ MODEL_FOR_PRETRAINING_MAPPING_NAMES,
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_MAPPING_NAMES,
+)
from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
from ..utils.versions import importlib_metadata
@@ -54,24 +55,27 @@
logger = logging.get_logger(__name__)
-def _generate_supported_model_classes(
+def _generate_supported_model_class_names(
model_name: Type[PretrainedConfig],
supported_tasks: Optional[Union[str, List[str]]] = None,
-) -> List[Type[PreTrainedModel]]:
+) -> List[str]:
- model_config_class = CONFIG_MAPPING[model_name]
task_mapping = {
- "default": MODEL_MAPPING,
- "pretraining": MODEL_FOR_PRETRAINING_MAPPING,
- "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
- "masked-lm": MODEL_FOR_MASKED_LM_MAPPING,
- "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING,
- "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
- "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
- "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
- "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
- "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
- "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+ "default": MODEL_MAPPING_NAMES,
+ "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
+ "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
+ "masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+ "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+ "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+ "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+ "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
+ "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+ "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+ "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
+ "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
+ "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+ "ctc": MODEL_FOR_CTC_MAPPING_NAMES,
+ "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
}
if supported_tasks is None:
@@ -79,88 +83,122 @@ def _generate_supported_model_classes(
if isinstance(supported_tasks, str):
supported_tasks = [supported_tasks]
- model_classes = []
+ model_class_names = []
for task in supported_tasks:
- model_class = task_mapping[task].get(model_config_class, None)
- if model_class:
- model_classes.append(model_class)
+ class_name = task_mapping[task].get(model_name, None)
+ if class_name:
+ model_class_names.append(class_name)
- return model_classes
+ return model_class_names
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"albert",
+ "bart",
"bert",
+ "blenderbot",
+ "blenderbot-small",
+ "bloom",
+ "clip",
+ "deberta",
+ "deberta-v2",
"distilbert",
- "mobilebert",
"electra",
- "megatron-bert",
"gpt2",
- "gptj",
"gpt_neo",
- "t5",
+ "gptj",
+ "hubert",
+ "layoutlm",
+ "lxmert",
+ "m2m_100",
+ "marian",
+ "mbart",
+ "megatron-bert",
+ "mobilebert",
+ "mt5",
+ "nezha",
+ "opt",
+ "pegasus",
+ "plbart",
"roberta",
- # TODO: add support for them as it should be quite easy to do so (small blocking issues).
- # "layoutlm",
- # "xlnet",
+ "speech_to_text",
+ "speech_to_text_2",
+ "swin",
+ "t5",
+ "trocr",
+ "vit",
+ "xglm",
+ # "xlnet",
]
_REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
if isinstance(item, dict):
- _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(**item))
+ _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item))
else:
- _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(item))
+ _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item))
_SPECIAL_SUPPORTED_MODELS = [
- GPT2DoubleHeadsModel,
+ "CLIPTextModel",
+ "CLIPVisionModel",
+ "GPT2DoubleHeadsModel",
+ "Speech2Text2Decoder",
+ "TrOCRDecoder",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
# XLNetForQuestionAnswering,
]
-_SUPPORTED_MODELS = tuple(
- sorted(list(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)), key=lambda c: c.__name__)
-)
+_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
-def embedding_override(self, input):
+def torch_nn_embedding(self, input):
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
-def torch_nn_layernorm_override(self, input):
+def torch_nn_functional_embedding(
+ input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
+):
+ return torch.empty(*input.shape, weight.shape[-1], device="meta")
+
+
+def torch_nn_layernorm(self, input):
return input
-def torch_nn_linear_override(self, input):
+def torch_nn_groupnorm(self, input):
+ return input
+
+
+def torch_nn_linear(self, input):
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
-def torch_relu_override(x):
+def torch_relu(x):
return x
-def torch_nn_relu_override(self, x):
+def torch_nn_relu(self, x):
return x
-def torch_nn_functional_relu_override(x, inplace=False):
+def torch_nn_functional_relu(x, inplace=False):
if not inplace:
raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
return x
-def torch_where_override(condition, x, y):
+def torch_where(condition, x, y):
# torch.where returns the broadcasted tensor of condition, x, and y,
# so hack it by using addition
return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
-def torch_abs_override(input, *, out=None):
- if out is None:
+def torch_abs(input, *, out=None):
+ if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
return input
-def torch_arange_override(*args, **kwargs):
+def torch_arange(*args, **kwargs):
n = len(args)
step = 1
if n == 1:
@@ -170,12 +208,18 @@ def torch_arange_override(*args, **kwargs):
start, end = args
else:
start, end, step = args
+ if isinstance(start, float):
+ start = int(start)
+ if isinstance(end, float):
+ end = int(end)
+ if isinstance(step, float):
+ step = int(step)
step = kwargs.get("step", step)
dtype = kwargs.get("dtype")
return torch.empty((end - start) // step, dtype=dtype, device="meta")
-def torch_cat_override(tensors, dim=None, axis=None, *, out=None):
+def torch_cat(tensors, dim=None, axis=None, *, out=None):
if dim is None and axis is None:
dim = 0
if dim is None and axis is not None:
@@ -189,7 +233,7 @@ def torch_cat_override(tensors, dim=None, axis=None, *, out=None):
return torch.empty(final_shape, device="meta")
-def torch_stack_override(tensors, dim=None, axis=None, *, out=None):
+def torch_stack(tensors, dim=None, axis=None, *, out=None):
if dim is None and axis is None:
dim = 0
if dim is None and axis is not None:
@@ -201,7 +245,7 @@ def torch_stack_override(tensors, dim=None, axis=None, *, out=None):
return torch.empty(shape, device="meta")
-def torch_add_override(input, other, *, alpha=1, out=None):
+def torch_add(input, other, *, alpha=1, out=None):
if not isinstance(input, torch.Tensor):
return torch.empty_like(other, device="meta")
if not isinstance(other, torch.Tensor):
@@ -215,15 +259,15 @@ def torch_add_override(input, other, *, alpha=1, out=None):
return torch.empty(shape, device="meta")
-def torch_mul_override(input, other, *, out=None):
- return torch_add_override(input, other, out=out)
+def torch_mul(input, other, *, out=None):
+ return torch_add(input, other, out=out)
-def torch_tensor_mul_override(self, other):
- return torch_mul_override(self, other)
+def torch_tensor_mul(self, other):
+ return torch_mul(self, other)
-def torch_matmul_override(input, other, *, out=None):
+def torch_matmul(input, other, *, out=None):
d1 = input.dim()
d2 = other.dim()
shape = None
@@ -259,7 +303,31 @@ def torch_matmul_override(input, other, *, out=None):
return torch.empty(*shape, device="meta")
-def torch_tensor_repeat_override(self, *sizes):
+def torch_bmm(input, mat2, *, out=None):
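+ # Batched matmul shape rule: (b, n, m) @ (b, m, p) -> (b, n, p); only the shapes are used here.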
+ if out is not None:
+ raise ValueError("Don't support in-place bmm for MetaTensor analysis")
+ batch_size, n, m = input.shape
+ _, _, p = mat2.shape
+ return torch.empty(batch_size, n, p, device="meta")
+
+
+def torch_baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):
+ if out is not None:
+ raise ValueError("Don't support in-place baddbmm for MetaTensor analysis")
+ return torch_bmm(batch1, batch2)
+
+
+def torch_tensor_baddbmm(self, batch1, batch2, *, beta=1, alpha=1, out=None):
+ return torch_baddbmm(self, batch1, batch2, beta=beta, alpha=alpha, out=out)
+
+
+def torch_einsum(equation, *operands):
+ # TODO: infer shape without performing the computation, this might be quite hard.
+ concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
+ return torch.einsum(equation, *concrete_operands).to("meta")
+
+
+def torch_tensor_repeat(self, *sizes):
shape = list(self.shape)
for i, x in enumerate(sizes):
shape[i] *= x
@@ -273,7 +341,106 @@ def torch_index_select(input, dim, index, *, out=None):
def torch_tensor_index_select(self, dim, index):
- return torch_tensor_index_select(self, dim, index)
+ return torch_index_select(self, dim, index)
+
+
+def torch_roll(input, shifts, dims=None):
+ return input
+
+
+def torch_flip(input, dims):
+ return input
+
+
+def torch_tensor_flip(self, dims):
+ return self
+
+
+def torch_nn_conv1d(self, input):
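+ # Infer the Conv1d output shape on meta tensors: "same" keeps the input length, otherwise the standard output-length formula is applied.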
+ l_in = input.shape[-1]
+ shape = None
+ padding = self.padding
+ if padding == "valid":
+ padding = (0, 0)
+ if padding == "same":
+ shape = list(input.shape)
+ if shape is None:
+ shape = list(input.shape)
+ l_out = math.floor(
+ (l_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ shape[-1] = l_out
+ shape[-2] = self.out_channels
+ return torch.empty(shape, device="meta")
+
+
+def torch_nn_conv2d(self, input):
+ h_in, w_in = input.shape[-2:]
+ shape = None
+ padding = self.padding
+ if padding == "valid":
+ padding = (0, 0)
+ if padding == "same":
+ shape = list(input.shape)
+ if shape is None:
+ shape = list(input.shape)
+ h_out = math.floor(
+ (h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
+ shape[-2:] = [h_out, w_out]
+ shape[-3] = self.out_channels
+ return torch.empty(shape, device="meta")
+
+
+def torch_squeeze(input, dim=None):
+ shape = list(input.shape)
+ if dim is not None:
+ if dim < 0:
+ dim = input.dim() + dim
+ if shape[dim] == 1:
+ shape.pop(dim)
+ else:
+ new_shape = []
+ for dim_value in shape:
+ if dim_value == 1:
+ continue
+ new_shape.append(dim_value)
+ shape = new_shape
+ return torch.empty(shape, device="meta")
+
+
+def torch_tensor_squeeze(self, dim=None):
+ return torch_squeeze(self, dim)
+
+
+def torch_unsqueeze(input, dim):
+ shape = list(input.shape)
+ if dim < 0:
+ dim = input.dim() + 1 + dim
+ shape.insert(dim, 1)
+ return torch.empty(shape, device="meta")
+
+
+def torch_tensor_unsqueeze(self, dim):
+ return torch_unsqueeze(self, dim)
+
+
+def torch_unique_consecutive(input, **kwargs):
+ output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
+ if isinstance(output, torch.Tensor):
+ return output.to("meta")
+ else:
+ return tuple(map(lambda x: x.to("meta"), output))
+
+
+def torch_nn_functional_one_hot(tensor, num_classes=-1):
+ if num_classes < 0:
+ raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
+ shape = list(tensor.shape) + [num_classes]
+ return torch.empty(shape, device="meta")
def torch_nn_mseloss(self, input, target):
@@ -300,29 +467,65 @@ def torch_nn_bcewithlogitsloss(self, input, target):
return torch.empty(shape, device="meta")
+def operator_getitem(a, b):
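+ # Indexing cannot run on meta tensors, so tensor indices are materialized as concrete int64 CPU tensors,
+ # the lookup is done on a CPU-shaped dummy of the input, and the result is moved back to the meta device.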
+ def to_concrete(t):
+ if isinstance(t, torch.Tensor):
+ concrete = torch.ones_like(t, device="cpu")
+ if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
+ concrete = concrete.to(torch.int64)
+ return concrete
+ return t
+
+ if isinstance(a, torch.Tensor):
+ # TODO: infer shape without performing the computation.
+ if isinstance(b, tuple):
+ b = tuple(map(to_concrete, b))
+ else:
+ b = to_concrete(b)
+ return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
+ return operator.getitem(a, b)
+
+
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
- torch.nn.Embedding: embedding_override,
- torch.nn.LayerNorm: torch_nn_layernorm_override,
- torch.nn.Linear: torch_nn_linear_override,
- torch.relu: torch_relu_override,
- torch.nn.functional.relu: torch_nn_functional_relu_override,
- torch.nn.ReLU: torch_nn_relu_override,
- torch.where: torch_where_override,
- torch.abs: torch_abs_override,
- torch.arange: torch_arange_override,
- torch.cat: torch_cat_override,
- torch.stack: torch_stack_override,
- torch.add: torch_add_override,
- torch.mul: torch_mul_override,
- torch.Tensor.mul: torch_tensor_mul_override,
- torch.matmul: torch_matmul_override,
- torch.Tensor.repeat: torch_tensor_repeat_override,
- # TODO: those might not be needed.
- # torch.index_select: torch_index_select,
- # torch.Tensor.index_select: torch_tensor_index_select,
+ torch.nn.Embedding: torch_nn_embedding,
+ torch.nn.functional.embedding: torch_nn_functional_embedding,
+ torch.nn.LayerNorm: torch_nn_layernorm,
+ torch.nn.GroupNorm: torch_nn_groupnorm,
+ torch.nn.Linear: torch_nn_linear,
+ torch.relu: torch_relu,
+ torch.nn.functional.relu: torch_nn_functional_relu,
+ torch.nn.ReLU: torch_nn_relu,
+ torch.where: torch_where,
+ torch.abs: torch_abs,
+ torch.arange: torch_arange,
+ torch.cat: torch_cat,
+ torch.stack: torch_stack,
+ torch.add: torch_add,
+ torch.mul: torch_mul,
+ torch.Tensor.mul: torch_tensor_mul,
+ torch.matmul: torch_matmul,
+ torch.bmm: torch_bmm,
+ torch.baddbmm: torch_baddbmm,
+ torch.Tensor.baddbmm: torch_tensor_baddbmm,
+ torch.einsum: torch_einsum,
+ torch.Tensor.repeat: torch_tensor_repeat,
+ torch.roll: torch_roll,
+ torch.flip: torch_flip,
+ torch.Tensor.flip: torch_tensor_flip,
+ torch.index_select: torch_index_select,
+ torch.Tensor.index_select: torch_tensor_index_select,
+ torch.nn.Conv1d: torch_nn_conv1d,
+ torch.nn.Conv2d: torch_nn_conv2d,
+ torch.squeeze: torch_squeeze,
+ torch.Tensor.squeeze: torch_tensor_squeeze,
+ torch.unsqueeze: torch_unsqueeze,
+ torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
+ torch.unique_consecutive: torch_unique_consecutive,
+ torch.nn.functional.one_hot: torch_nn_functional_one_hot,
torch.nn.MSELoss: torch_nn_mseloss,
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
+ operator.getitem: operator_getitem,
}
@@ -340,7 +543,6 @@ def shape(self):
@property
def dtype(self):
- return self.tracer.root.dtype
if hasattr(self, "_metadata") and self._metadata is not None:
return self._metadata.dtype
return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})
@@ -368,11 +570,12 @@ def __getattr__(self, k):
# we peephole optimize to the method invocation
return HFAttribute(self, k)
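+ # Record item assignment (proxy[indices] = values) as an operator.setitem node instead of executing it.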
+ def __setitem__(self, indices, values):
+ return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
+
def __contains__(self, key):
- # To handle cases such as :
- # `"some_key" in kwargs`
- if self.node.op == "placeholder":
- return False
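+ # Answer membership checks ("key in proxy") from the recorded example metadata when it is available.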
+ if hasattr(self, "_metadata") and self._metadata is not None:
+ return key in self._metadata
return super().__contains__(key)
@@ -446,14 +649,14 @@ class HFTracer(Tracer):
regular PyTorch torch.fx.Proxy.
"""
+ # Feature flag for proxying accesses to buffer values
+ proxy_buffer_attributes: bool = True
allow_insert_stateless_mods: bool = True
- _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]
+ _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
- def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False):
+ def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
- super().__init__(
- autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching
- )
+ super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
if not is_torch_fx_available():
torch_version = version.parse(importlib_metadata.version("torch"))
@@ -466,22 +669,24 @@ def _generate_dummy_input(
self, model: PreTrainedModel, input_name: str, shape: List[int]
) -> Dict[str, torch.Tensor]:
"""Generates dummy input for model inference recording."""
- model_class = model.__class__
+ # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
+ # from pickle, or from the "__class__" attribute in the general case.
+ model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
device = model.device
inputs_dict = {}
if input_name in ["labels", "start_positions", "end_positions"]:
batch_size = shape[0]
- if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
+ if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
- elif model_class in [
- *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING),
- XLNetForQuestionAnswering,
+ elif model_class_name in [
+ *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
+ "XLNetForQuestionAnswering",
]:
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
- elif model_class in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
+ elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
raise ValueError(
"Could not retrieve the problem type for the sequence classification task, please set "
@@ -505,23 +710,72 @@ def _generate_dummy_input(
)
inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
- elif model_class in [
- *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
- *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
+ elif model_class_name in [
+ *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
+ *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
]:
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
- elif model_class in [
- *get_values(MODEL_FOR_PRETRAINING_MAPPING),
- *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
- *get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
- *get_values(MODEL_FOR_MASKED_LM_MAPPING),
- *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
- GPT2DoubleHeadsModel,
+ elif model_class_name in [
+ *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
+ *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
+ *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
+ *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
+ *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
+ "GPT2DoubleHeadsModel",
]:
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
else:
- raise NotImplementedError(f"{model_class} not supported yet.")
-
+ raise NotImplementedError(f"{model_class_name} not supported yet.")
+ elif "pixel_values" in input_name:
+ batch_size = shape[0]
+ image_size = getattr(model.config, "image_size", None)
+ if image_size is None:
+ if hasattr(model.config, "vision_config"):
+ image_size = model.config.vision_config.image_size
+ elif hasattr(model.config, "encoder"):
+ image_size = model.config.encoder.image_size
+ else:
+ raise AttributeError('Could not find the "image_size" field in the model config')
+
+ # If no num_channels is in the config, use some arbitrary value.
+ num_channels = getattr(model.config, "num_channels", 3)
+ if not isinstance(image_size, collections.abc.Iterable):
+ image_size = (image_size, image_size)
+ height, width = image_size
+ inputs_dict[input_name] = torch.zeros(
+ batch_size, num_channels, height, width, dtype=torch.float32, device=device
+ )
+ elif "bbox" in input_name:
+ inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device)
+ elif "input_features" in input_name:
+ inputs_dict[input_name] = torch.zeros(
+ *shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
+ )
+ elif "visual_feats" in input_name:
+ inputs_dict[input_name] = torch.zeros(
+ shape
+ + [
+ model.config.visual_feat_dim,
+ ],
+ dtype=torch.float,
+ device=device,
+ )
+ elif "visual_pos" in input_name:
+ inputs_dict[input_name] = torch.zeros(
+ shape
+ + [
+ model.config.visual_pos_dim,
+ ],
+ dtype=torch.float,
+ device=device,
+ )
+ elif "inputs" in input_name:
+ inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
+ elif "input_values" in input_name:
+ batch_size, _ = shape
+ # Generating big sequence length for audio inputs.
+ seq_length = _generate_random_int(low=10000, high=20000)
+ inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
elif "mask" in input_name or "ids" in input_name:
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
else:
@@ -553,6 +807,8 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr
if kind == "call_function":
meta_target = _MANUAL_META_OVERRIDES.get(target, target)
meta_out = meta_target(*args_metas, **kwargs_metas)
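+ # Normalize the recorded example output to the meta device.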
+ if isinstance(meta_out, torch.Tensor):
+ meta_out = meta_out.to(device="meta")
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
meta_target = _MANUAL_META_OVERRIDES.get(method, method)
@@ -598,7 +854,38 @@ def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, "_disable_module_getattr", False):
return attr_val
else:
- return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
+ # return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
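+ # If attr_val is one of the tensors in collection_to_search, return (and cache) a "get_attr" proxy
+ # for it so the attribute access is captured in the traced graph.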
+ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
+ for n, p in collection_to_search:
+ if attr_val is p:
+ if n not in parameter_proxy_cache:
+ kwargs = {}
+ if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
+ kwargs["proxy_factory_fn"] = (
+ None
+ if not self.param_shapes_constant
+ else lambda node: ParameterProxy(self, node, n, attr_val)
+ )
+ val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
+ parameter_proxy_cache[n] = val_proxy
+ return parameter_proxy_cache[n]
+ return None
+
+ if isinstance(attr_val, torch.nn.Parameter):
+ maybe_parameter_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_parameters(), parameter_proxy_cache
+ )
+ if maybe_parameter_proxy is not None:
+ return maybe_parameter_proxy
+
+ if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
+ maybe_buffer_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_buffers(), parameter_proxy_cache
+ )
+ if maybe_buffer_proxy is not None:
+ return maybe_buffer_proxy
+
+ return attr_val
def call_module(self, m, forward, args, kwargs):
self.orig_forward = forward
@@ -609,15 +896,49 @@ def proxy(self, node):
def trace(
self,
- root: PreTrainedModel,
+ root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
- method_names: Optional[Iterable[str]] = None,
+ dummy_inputs: Optional[Dict[str, Any]] = None,
+ complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
) -> Graph:
+ """
+ Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
+ `torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
+ the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
+ `torch.nn.Module` instance to use as the root and add embedded constants to.
+
+ Args:
+ root (`torch.nn.Module` or `Callable`):
+ Either a `torch.nn.Module` or a function to be traced through. If `root` is not a
+ [`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
+ concrete_args (`Dict[str, Any]`, *optional*):
+ Concrete arguments that should not be treated as Proxies
+ dummy_inputs (`Dict[str, Any]`, *optional*):
+ The dummy inputs needed to handle data-dependent control-flow if `root` is not a
+ [`~transformers.PreTrainedModel`]. It can also be used when `root` is a
+ [`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
+ complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
+ If `True` and `dummy_inputs` is specified, every argument that `root` can take that is not in
+ `dummy_inputs` and not in `concrete_args` will be added to `concrete_args`; otherwise nothing is added.
+
+ Returns:
+ `torch.fx.Graph`:
+ A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.
+
+ """
+ sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root)
if concrete_args is None:
concrete_args = {}
- sig = inspect.signature(root.forward)
+ if dummy_inputs is not None and complete_concrete_args_with_inputs_not_in_dummy_inputs:
+ for param in sig.parameters.values():
+ if param.name in dummy_inputs:
+ continue
+ if param.default is inspect.Parameter.empty:
+ raise ValueError(f"You need to specify a default value for the parameter {param.name}.")
+ concrete_args.update({p.name: p.default for p in sig.parameters.values() if p.name not in dummy_inputs})
+
input_names = sig.parameters.keys() - concrete_args.keys()
# Creating a random input shape to generate dummy inputs.
@@ -625,15 +946,31 @@ def trace(
sequence_length = _generate_random_int()
shape = [batch_size, sequence_length]
- if root.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
+ if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
num_choices = _generate_random_int(low=2, high=5)
shape.insert(1, num_choices)
- inputs = {}
+ inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
for input_name in input_names:
- inputs.update(self._generate_dummy_input(root, input_name, shape))
-
- concrete_metas = {input_name: input_.to("meta") for input_name, input_ in inputs.items()}
+ if input_name in inputs:
+ continue
+ # We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
+ # be able to use HFTracer._generate_dummy_input.
+ if isinstance(root, PreTrainedModel) or type(root).__qualname__.startswith("_deserialize_graph_module"):
+ inputs.update(self._generate_dummy_input(root, input_name, shape))
+ else:
+ raise RuntimeError(
+ f"Could not generate input named {input_name} for because root is not a"
+ " transformers.PreTrainedModel."
+ )
+
+ concrete_metas = {
+ input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
+ for input_name, input_ in inputs.items()
+ }
+ for param in sig.parameters.values():
+ if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
+ concrete_metas[f"**{param.name}"] = {}
self.meta_args = concrete_metas
self.patched_torch_methods = {
target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
@@ -650,18 +987,32 @@ def trace(
for name, (_, orig) in self.patched_torch_methods.items():
setattr(torch, name, orig)
- # TODO: keep this until necessary.
# This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
- # A PR that solves this was posted: https://github.com/pytorch/pytorch/pull/59569 but it was not merged yet.
for node in self.graph.nodes:
if node.op == "placeholder":
# Removing default values for inputs as the forward pass will fail with them.
if node.target in input_names:
node.args = ()
+ # Without this, torch.jit.script fails because the input type is Optional[torch.Tensor].
+ # It cannot infer the attributes and methods the input should have, and fails.
+ node.type = torch.Tensor
# It is a concrete arg so it is not used and should be removed.
else:
- self.graph.erase_node(node)
+ to_visit = [node]
+ to_delete = collections.OrderedDict()
+ while to_visit:
+ n = to_visit.pop(0)
+ to_delete[n] = None
+ to_visit += list(n.users.keys())
+
+ for user in reversed(to_delete.keys()):
+ self.graph.erase_node(user)
+
+ # TODO: this solves GraphModule creation.
+ # Without this, the "Tuple" return-type annotation causes code execution to fail.
+ if node.op == "output":
+ node.type = None
return self.graph
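The hunk above extends `HFTracer.trace` so that a plain `torch.nn.Module` (or even a free function) can be traced, as long as `dummy_inputs` is provided; arguments with defaults that are not in `dummy_inputs` are folded into `concrete_args`. A minimal sketch of that new call pattern follows; the toy module, shapes, and names are illustrative assumptions, not part of the patch.

```python
# Sketch only: exercises the updated trace(root, concrete_args, dummy_inputs, ...) signature.
import torch
from transformers.utils.fx import HFTracer


class TinyModel(torch.nn.Module):  # hypothetical module, not from the patch
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, scale=1.0):
        return self.linear(x) * scale


model = TinyModel()
tracer = HFTracer()
# `root` is not a PreTrainedModel, so `dummy_inputs` is required; `scale` has a default
# and is therefore completed into `concrete_args` automatically.
graph = tracer.trace(model, dummy_inputs={"x": torch.randn(2, 4)})
traced = torch.fx.GraphModule(model, graph)
print(traced.code)
```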
@@ -719,9 +1070,32 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool
)
+def get_concrete_args(model: nn.Module, input_names: List[str]):
+ sig = inspect.signature(model.forward)
+
+ if not (set(input_names) <= set(sig.parameters.keys())):
+ formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
+ formatted_allowed_input_names = ", ".join(sig.parameters.keys())
+ raise ValueError(
+ f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
+ f" {formatted_allowed_input_names}"
+ )
+
+ return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
+
+
+def check_if_model_is_supported(model: PreTrainedModel):
+ if model.__class__.__name__ not in _SUPPORTED_MODELS:
+ supported_model_names = ", ".join(_SUPPORTED_MODELS)
+ raise NotImplementedError(
+ f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
+ )
+
+
def symbolic_trace(
model: PreTrainedModel,
input_names: Optional[List[str]] = None,
+ disable_check: bool = False,
) -> GraphModule:
"""
@@ -732,6 +1106,8 @@ def symbolic_trace(
The model to trace.
input_names (`List[str]`, *optional*):
The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
+ disable_check (`bool`, *optional*, defaults to `False`):
+ If `True`, no check is done before trying to trace the model. This is mostly useful for debugging purposes.
Returns:
`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.
@@ -747,26 +1123,21 @@ def symbolic_trace(
if input_names is None:
input_names = model.dummy_inputs.keys()
- sig = inspect.signature(model.forward)
- concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
+ input_names = list(input_names)
+ concrete_args = get_concrete_args(model, input_names)
- if not isinstance(model, _SUPPORTED_MODELS):
- supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS))
- raise NotImplementedError(
- f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
- )
+ if not disable_check:
+ check_if_model_is_supported(model)
# Tracing.
tracer = HFTracer()
traced_graph = tracer.trace(model, concrete_args=concrete_args)
traced = torch.fx.GraphModule(model, traced_graph)
- # Copy all the original attributes to the traced GraphModule.
- regular_module_attributes = dir(nn.Module())
- for name in dir(model):
- attr = getattr(model, name)
- if name.startswith("_") or name in regular_module_attributes:
- continue
- setattr(traced, name, deepcopy(attr))
+ traced.config = model.config
+ # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
+ # _generate_dummy_input, where the model class is needed.
+ traced.class_for_deserialization = model.__class__
+ traced.device = model.device
return traced
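With the attribute deep-copy gone, the traced `GraphModule` now only carries `config`, `class_for_deserialization`, and `device`. A short usage sketch of the public entry point; the model checkpoint and input names are illustrative.

```python
from transformers import BertModel
from transformers.utils.fx import symbolic_trace

model = BertModel.from_pretrained("bert-base-uncased")
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

# Only a handful of attributes survive on the GraphModule after this change.
print(traced.config.hidden_size)
print(traced.class_for_deserialization.__name__)  # "BertModel"
print(traced.device)
```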
diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index bea5b3dd4775..9e8ae759d92f 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -16,8 +16,10 @@
"""
import inspect
+import tempfile
from collections import OrderedDict, UserDict
-from contextlib import ExitStack
+from collections.abc import MutableMapping
+from contextlib import ExitStack, contextmanager
from dataclasses import fields
from enum import Enum
from typing import Any, ContextManager, List, Tuple
@@ -239,7 +241,7 @@ def to_tuple(self) -> Tuple[Any]:
return tuple(self[k] for k in self.keys())
-class ExplicitEnum(Enum):
+class ExplicitEnum(str, Enum):
"""
Enum with more explicit error message for missing values.
"""
@@ -310,3 +312,26 @@ def find_labels(model_class):
return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
else:
return [p for p in signature.parameters if "label" in p]
+
+
+def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
+ """Flatten a nested dict into a single level dict."""
+
+ def _flatten_dict(d, parent_key="", delimiter="."):
+ for k, v in d.items():
+ key = str(parent_key) + delimiter + str(k) if parent_key else k
+ if v and isinstance(v, MutableMapping):
+ yield from flatten_dict(v, key, delimiter=delimiter).items()
+ else:
+ yield key, v
+
+ return dict(_flatten_dict(d, parent_key, delimiter))
+
+
+@contextmanager
+def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
+ if use_temp_dir:
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ yield tmp_dir
+ else:
+ yield working_dir
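Quick sketch of the two helpers added to `utils/generic.py`; the nested dict and directory name below are made up for illustration.

```python
from transformers.utils.generic import flatten_dict, working_or_temp_dir

nested = {"optim": {"lr": 3e-4, "betas": {"b1": 0.9}}, "seed": 42}
print(flatten_dict(nested))
# {'optim.lr': 0.0003, 'optim.betas.b1': 0.9, 'seed': 42}

# With use_temp_dir=True, a TemporaryDirectory is yielded and cleaned up on exit;
# otherwise `working_dir` is yielded unchanged.
with working_or_temp_dir("./my-model", use_temp_dir=True) as work_dir:
    print(work_dir)  # a temporary path, not ./my-model
```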
diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py
index 7386fe34f521..07164e735db9 100644
--- a/src/transformers/utils/hub.py
+++ b/src/transformers/utils/hub.py
@@ -14,33 +14,36 @@
"""
Hub utilities: utilities related to downloading and caching models
"""
-import copy
-import fnmatch
-import io
import json
import os
+import re
import shutil
-import subprocess
import sys
-import tarfile
-import tempfile
+import traceback
import warnings
from contextlib import contextmanager
-from functools import partial
-from hashlib import sha256
from pathlib import Path
-from typing import BinaryIO, Dict, List, Optional, Tuple, Union
-from urllib.parse import urlparse
+from typing import Dict, List, Optional, Tuple, Union
from uuid import uuid4
-from zipfile import ZipFile, is_zipfile
+import huggingface_hub
import requests
-from filelock import FileLock
-from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
+from huggingface_hub import (
+ CommitOperationAdd,
+ HfFolder,
+ create_commit,
+ create_repo,
+ hf_hub_download,
+ hf_hub_url,
+ whoami,
+)
+from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests.exceptions import HTTPError
from transformers.utils.logging import tqdm
from . import __version__, logging
+from .generic import working_or_temp_dir
from .import_utils import (
ENV_VARS_TRUE_VALUES,
_tf_version,
@@ -66,7 +69,7 @@ def is_offline_mode():
hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
-default_cache_path = os.path.join(hf_cache_home, "transformers")
+default_cache_path = os.path.join(hf_cache_home, "hub")
# Onetime move from the old location to the new one if no ENV variable has been set.
if (
@@ -77,17 +80,18 @@ def is_offline_mode():
and "TRANSFORMERS_CACHE" not in os.environ
):
logger.warning(
- "In Transformers v4.0.0, the default path to cache downloaded models changed from "
- "'~/.cache/torch/transformers' to '~/.cache/huggingface/transformers'. Since you don't seem to have overridden "
- "and '~/.cache/torch/transformers' is a directory that exists, we're moving it to "
- "'~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should "
- "only see this message once."
+ "In Transformers v4.0.0, the default path to cache downloaded models changed from"
+ " '~/.cache/torch/transformers' to '~/.cache/huggingface/transformers'. Since you don't seem to have"
+ " overridden and '~/.cache/torch/transformers' is a directory that exists, we're moving it to"
+ " '~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should"
+ " only see this message once."
)
shutil.move(old_default_cache_path, default_cache_path)
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
-TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
+HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", PYTORCH_TRANSFORMERS_CACHE)
+TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", HUGGINGFACE_HUB_CACHE)
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
SESSION_ID = uuid4().hex
@@ -97,7 +101,7 @@ def is_offline_mode():
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
_staging_mode = os.environ.get("HUGGINGFACE_CO_STAGING", "NO").upper() in ENV_VARS_TRUE_VALUES
-_default_endpoint = "https://moon-staging.huggingface.co" if _staging_mode else "https://huggingface.co"
+_default_endpoint = "https://hub-ci.huggingface.co" if _staging_mode else "https://huggingface.co"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = _default_endpoint
if os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", None) is not None:
@@ -109,93 +113,7 @@ def is_offline_mode():
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", None)
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", HUGGINGFACE_CO_RESOLVE_ENDPOINT)
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
-
-
-def is_remote_url(url_or_filename):
- parsed = urlparse(url_or_filename)
- return parsed.scheme in ("http", "https")
-
-
-def hf_bucket_url(
- model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
-) -> str:
- """
- Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting
- to Cloudfront (a Content Delivery Network, or CDN) for large files.
-
- Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our
- bandwidth costs).
-
- Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here
- because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront
- in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache
- can't ever be stale.
-
- In terms of client-side caching from this library, we base our caching on the objects' ETag. An object' ETag is:
- its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0
- are not shared with those new files, because the cached file's name contains a hash of the url (which changed).
- """
- if subfolder is not None:
- filename = f"{subfolder}/{filename}"
-
- if mirror:
- if mirror in ["tuna", "bfsu"]:
- raise ValueError("The Tuna and BFSU mirrors are no longer available. Try removing the mirror argument.")
- legacy_format = "/" not in model_id
- if legacy_format:
- return f"{mirror}/{model_id}-{filename}"
- else:
- return f"{mirror}/{model_id}/{filename}"
-
- if revision is None:
- revision = "main"
- return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
-
-
-def url_to_filename(url: str, etag: Optional[str] = None) -> str:
- """
- Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's,
- delimited by a period. If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can
- identify it as a HDF5 file (see
- https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
- """
- url_bytes = url.encode("utf-8")
- filename = sha256(url_bytes).hexdigest()
-
- if etag:
- etag_bytes = etag.encode("utf-8")
- filename += "." + sha256(etag_bytes).hexdigest()
-
- if url.endswith(".h5"):
- filename += ".h5"
-
- return filename
-
-
-def filename_to_url(filename, cache_dir=None):
- """
- Return the url and etag (which may be `None`) stored for *filename*. Raise `EnvironmentError` if *filename* or its
- stored metadata do not exist.
- """
- if cache_dir is None:
- cache_dir = TRANSFORMERS_CACHE
- if isinstance(cache_dir, Path):
- cache_dir = str(cache_dir)
-
- cache_path = os.path.join(cache_dir, filename)
- if not os.path.exists(cache_path):
- raise EnvironmentError(f"file {cache_path} not found")
-
- meta_path = cache_path + ".json"
- if not os.path.exists(meta_path):
- raise EnvironmentError(f"file {meta_path} not found")
-
- with open(meta_path, encoding="utf-8") as meta_file:
- metadata = json.load(meta_file)
- url = metadata["url"]
- etag = metadata["etag"]
-
- return url, etag
+HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
@@ -215,6 +133,8 @@ def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
cache_dir = TRANSFORMERS_CACHE
elif isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
+ if not os.path.isdir(cache_dir):
+ return []
cached_models = []
for file in os.listdir(cache_dir):
@@ -231,108 +151,6 @@ def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
return cached_models
-def cached_path(
- url_or_filename,
- cache_dir=None,
- force_download=False,
- proxies=None,
- resume_download=False,
- user_agent: Union[Dict, str, None] = None,
- extract_compressed_file=False,
- force_extract=False,
- use_auth_token: Union[bool, str, None] = None,
- local_files_only=False,
-) -> Optional[str]:
- """
- Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file
- and cache it, and return the path to the cached file. If it's already a local path, make sure the file exists and
- then return the path
-
- Args:
- cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
- force_download: if True, re-download the file even if it's already cached in the cache dir.
- resume_download: if True, resume the download if incompletely received file is found.
- user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
- use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,
- will get token from ~/.huggingface.
- extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed
- file in a folder along the archive.
- force_extract: if True when extract_compressed_file is True and the archive was already extracted,
- re-extract the archive and override the folder where it was extracted.
-
- Return:
- Local path (string) of file or if networking is off, last version of file cached on disk.
-
- Raises:
- In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
- """
- if cache_dir is None:
- cache_dir = TRANSFORMERS_CACHE
- if isinstance(url_or_filename, Path):
- url_or_filename = str(url_or_filename)
- if isinstance(cache_dir, Path):
- cache_dir = str(cache_dir)
-
- if is_offline_mode() and not local_files_only:
- logger.info("Offline mode: forcing local_files_only=True")
- local_files_only = True
-
- if is_remote_url(url_or_filename):
- # URL, so get it from the cache (downloading if necessary)
- output_path = get_from_cache(
- url_or_filename,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- user_agent=user_agent,
- use_auth_token=use_auth_token,
- local_files_only=local_files_only,
- )
- elif os.path.exists(url_or_filename):
- # File, and it exists.
- output_path = url_or_filename
- elif urlparse(url_or_filename).scheme == "":
- # File, but it doesn't exist.
- raise EnvironmentError(f"file {url_or_filename} not found")
- else:
- # Something unknown
- raise ValueError(f"unable to parse {url_or_filename} as a URL or as a local path")
-
- if extract_compressed_file:
- if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
- return output_path
-
- # Path where we extract compressed archives
- # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
- output_dir, output_file = os.path.split(output_path)
- output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
- output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
-
- if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
- return output_path_extracted
-
- # Prevent parallel extractions
- lock_path = output_path + ".lock"
- with FileLock(lock_path):
- shutil.rmtree(output_path_extracted, ignore_errors=True)
- os.makedirs(output_path_extracted)
- if is_zipfile(output_path):
- with ZipFile(output_path, "r") as zip_file:
- zip_file.extractall(output_path_extracted)
- zip_file.close()
- elif tarfile.is_tarfile(output_path):
- tar_file = tarfile.open(output_path)
- tar_file.extractall(output_path_extracted)
- tar_file.close()
- else:
- raise EnvironmentError(f"Archive format of {output_path} could not be identified")
-
- return output_path_extracted
-
- return output_path
-
-
def define_sagemaker_information():
try:
instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
@@ -382,223 +200,214 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua
-class RepositoryNotFoundError(HTTPError):
+def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
"""
- Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
- not have access to.
+ Explores the cache to return the latest cached file for a given revision.
"""
+ if revision is None:
+ revision = "main"
+ model_id = repo_id.replace("/", "--")
+ model_cache = os.path.join(cache_dir, f"models--{model_id}")
+ if not os.path.isdir(model_cache):
+ # No cache for this model
+ return None
+ for subfolder in ["refs", "snapshots"]:
+ if not os.path.isdir(os.path.join(model_cache, subfolder)):
+ return None
+
+ # Resolve refs (for instance to convert main to the associated commit sha)
+ cached_refs = os.listdir(os.path.join(model_cache, "refs"))
+ if revision in cached_refs:
+ with open(os.path.join(model_cache, "refs", revision)) as f:
+ revision = f.read()
+
+ cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
+ if revision not in cached_shas:
+ # No cache for this revision and we won't try to return a random revision
+ return None
-class EntryNotFoundError(HTTPError):
- """Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
-
-
-class RevisionNotFoundError(HTTPError):
- """Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
-
+ cached_file = os.path.join(model_cache, "snapshots", revision, filename)
+ return cached_file if os.path.isfile(cached_file) else None
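For reference, `try_to_load_from_cache` assumes the new huggingface_hub cache layout; below is a sketch of the directory structure it walks (the repo id and commit sha are made up).

```python
# <cache_dir>/
#   models--bert-base-uncased/
#     refs/
#       main              <- file whose content is the commit sha that `main` points to
#     snapshots/
#       0a6aa91.../        <- one folder per cached commit sha
#         config.json
#         pytorch_model.bin
from transformers.utils.hub import TRANSFORMERS_CACHE, try_to_load_from_cache

# Returns the snapshot path of config.json if it is cached for `main`, else None.
print(try_to_load_from_cache(TRANSFORMERS_CACHE, "bert-base-uncased", "config.json"))
```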
-def _raise_for_status(request):
- """
- Internal version of `request.raise_for_status()` that will refine a potential HTTPError.
- """
- if "X-Error-Code" in request.headers:
- error_code = request.headers["X-Error-Code"]
- if error_code == "RepoNotFound":
- raise RepositoryNotFoundError(f"404 Client Error: Repository Not Found for url: {request.url}")
- elif error_code == "EntryNotFound":
- raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {request.url}")
- elif error_code == "RevisionNotFound":
- raise RevisionNotFoundError((f"404 Client Error: Revision Not Found for url: {request.url}"))
- request.raise_for_status()
+# If huggingface_hub changes the class of error for this to FileNotFoundError, we will be able to remove this
+# message-matching workaround in the future.
+LOCAL_FILES_ONLY_HF_ERROR = (
+ "Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co "
+ "look-ups and downloads online, set 'local_files_only' to False."
+)
-def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
+# In the future, this ugly contextmanager can be removed when huggingface_hub has a released version where we can
+# activate/deactivate progress bars.
+@contextmanager
+def _patch_hf_hub_tqdm():
"""
- Download remote file. Do not gobble up errors.
+ A context manager that makes huggingface_hub use the Transformers version of tqdm (whose behavior is controlled
+ by the logging utilities).
"""
- headers = copy.deepcopy(headers)
- if resume_size > 0:
- headers["Range"] = f"bytes={resume_size}-"
- r = requests.get(url, stream=True, proxies=proxies, headers=headers)
- _raise_for_status(r)
- content_length = r.headers.get("Content-Length")
- total = resume_size + int(content_length) if content_length is not None else None
- # `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
- # and can be set using `utils.logging.enable/disable_progress_bar()`
- progress = tqdm(
- unit="B",
- unit_scale=True,
- unit_divisor=1024,
- total=total,
- initial=resume_size,
- desc="Downloading",
- )
- for chunk in r.iter_content(chunk_size=1024):
- if chunk: # filter out keep-alive new chunks
- progress.update(len(chunk))
- temp_file.write(chunk)
- progress.close()
+ old_tqdm = huggingface_hub.file_download.tqdm
+ huggingface_hub.file_download.tqdm = tqdm
+ yield
+ huggingface_hub.file_download.tqdm = old_tqdm
-def get_from_cache(
- url: str,
- cache_dir=None,
- force_download=False,
- proxies=None,
- etag_timeout=10,
- resume_download=False,
- user_agent: Union[Dict, str, None] = None,
- use_auth_token: Union[bool, str, None] = None,
- local_files_only=False,
-) -> Optional[str]:
+def cached_file(
+ path_or_repo_id: Union[str, os.PathLike],
+ filename: str,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ subfolder: str = "",
+ user_agent: Optional[Union[str, Dict[str, str]]] = None,
+ _raise_exceptions_for_missing_entries=True,
+ _raise_exceptions_for_connection_errors=True,
+):
"""
- Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the
- path to the cached file.
+ Tries to locate a file in a local folder and repo, downloads and caches it if necessary.
- Return:
- Local path (string) of file or if networking is off, last version of file cached on disk.
+ Args:
+ path_or_repo_id (`str` or `os.PathLike`):
+ This can be either:
- Raises:
- In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
- """
- if cache_dir is None:
- cache_dir = TRANSFORMERS_CACHE
- if isinstance(cache_dir, Path):
- cache_dir = str(cache_dir)
+ - a string, the *model id* of a model repo on huggingface.co.
+ - a path to a *directory* potentially containing the file.
+ filename (`str`):
+ The name of the file to locate in `path_or_repo`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force a (re-)download of the configuration files, overriding the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the file from local files (no download will be attempted).
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
- os.makedirs(cache_dir, exist_ok=True)
+
- headers = {"user-agent": http_user_agent(user_agent)}
- if isinstance(use_auth_token, str):
- headers["authorization"] = f"Bearer {use_auth_token}"
- elif use_auth_token:
- token = HfFolder.get_token()
- if token is None:
- raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
- headers["authorization"] = f"Bearer {token}"
+ Passing `use_auth_token=True` is required when you want to use a private model.
- url_to_download = url
- etag = None
- if not local_files_only:
- try:
- r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
- _raise_for_status(r)
- etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
- # We favor a custom header indicating the etag of the linked resource, and
- # we fallback to the regular etag header.
- # If we don't have any of those, raise an error.
- if etag is None:
- raise OSError(
- "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
- )
- # In case of a redirect,
- # save an extra redirect on the request.get call,
- # and ensure we download the exact atomic version even if it changed
- # between the HEAD and the GET (unlikely, but hey).
- if 300 <= r.status_code <= 399:
- url_to_download = r.headers["Location"]
- except (
- requests.exceptions.SSLError,
- requests.exceptions.ProxyError,
- RepositoryNotFoundError,
- EntryNotFoundError,
- RevisionNotFoundError,
- ):
- # Actually raise for those subclasses of ConnectionError
- # Also raise the custom errors coming from a non existing repo/branch/file as they are caught later on.
- raise
- except (HTTPError, requests.exceptions.ConnectionError, requests.exceptions.Timeout):
- # Otherwise, our Internet connection is down.
- # etag is None
- pass
-
- filename = url_to_filename(url, etag)
-
- # get cache path to put the file
- cache_path = os.path.join(cache_dir, filename)
-
- # etag is None == we don't have a connection or we passed local_files_only.
- # try to get the last downloaded one
- if etag is None:
- if os.path.exists(cache_path):
- return cache_path
- else:
- matching_files = [
- file
- for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
- if not file.endswith(".json") and not file.endswith(".lock")
- ]
- if len(matching_files) > 0:
- return os.path.join(cache_dir, matching_files[-1])
- else:
- # If files cannot be found and local_files_only=True,
- # the models might've been found if local_files_only=False
- # Notify the user about that
- if local_files_only:
- raise FileNotFoundError(
- "Cannot find the requested files in the cached path and outgoing traffic has been"
- " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
- " to False."
- )
- else:
- raise ValueError(
- "Connection error, and we cannot find the requested files in the cached path."
- " Please try again or make sure your Internet connection is on."
- )
-
- # From now on, etag is not None.
- if os.path.exists(cache_path) and not force_download:
- return cache_path
-
- # Prevent parallel downloads of the same file with a lock.
- lock_path = cache_path + ".lock"
- with FileLock(lock_path):
-
- # If the download just completed while the lock was activated.
- if os.path.exists(cache_path) and not force_download:
- # Even if returning early like here, the lock will be released.
- return cache_path
-
- if resume_download:
- incomplete_path = cache_path + ".incomplete"
-
- @contextmanager
- def _resumable_file_manager() -> "io.BufferedWriter":
- with open(incomplete_path, "ab") as f:
- yield f
-
- temp_file_manager = _resumable_file_manager
- if os.path.exists(incomplete_path):
- resume_size = os.stat(incomplete_path).st_size
- else:
- resume_size = 0
- else:
- temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
- resume_size = 0
+
- # Download to temporary file, then copy to cache dir once finished.
- # Otherwise you get corrupt cache entries if the download gets interrupted.
- with temp_file_manager() as temp_file:
- logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")
+ Returns:
+ `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo).
- http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
+ Examples:
- logger.info(f"storing {url} in cache at {cache_path}")
- os.replace(temp_file.name, cache_path)
+ ```python
+ # Download a model weight from the Hub and cache it.
+ model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
+ ```"""
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+ if subfolder is None:
+ subfolder = ""
+
+ path_or_repo_id = str(path_or_repo_id)
+ full_filename = os.path.join(subfolder, filename)
+ if os.path.isdir(path_or_repo_id):
+ resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
+ if not os.path.isfile(resolved_file):
+ if _raise_exceptions_for_missing_entries:
+ raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")
+ else:
+ return None
+ return resolved_file
- # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
- umask = os.umask(0o666)
- os.umask(umask)
- os.chmod(cache_path, 0o666 & ~umask)
+ if cache_dir is None:
+ cache_dir = TRANSFORMERS_CACHE
+ if isinstance(cache_dir, Path):
+ cache_dir = str(cache_dir)
+ user_agent = http_user_agent(user_agent)
+ try:
+ # Load from URL or cache if already cached
+ with _patch_hf_hub_tqdm():
+ resolved_file = hf_hub_download(
+ path_or_repo_id,
+ filename,
+ subfolder=None if len(subfolder) == 0 else subfolder,
+ revision=revision,
+ cache_dir=cache_dir,
+ user_agent=user_agent,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ use_auth_token=use_auth_token,
+ local_files_only=local_files_only,
+ )
- logger.info(f"creating metadata file for {cache_path}")
- meta = {"url": url, "etag": etag}
- meta_path = cache_path + ".json"
- with open(meta_path, "w") as meta_file:
- json.dump(meta, meta_file)
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
+ "pass a token having permission to this repo with `use_auth_token` or log in with "
+ "`huggingface-cli login` and pass `use_auth_token=True`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
+ "for this model name. Check the model page at "
+ f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
+ )
+ except EntryNotFoundError:
+ if not _raise_exceptions_for_missing_entries:
+ return None
+ if revision is None:
+ revision = "main"
+ raise EnvironmentError(
+ f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
+ f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
+ )
+ except HTTPError as err:
+ # First we try to see if we have a cached version (not up to date):
+ resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
+ if resolved_file is not None:
+ return resolved_file
+ if not _raise_exceptions_for_connection_errors:
+ return None
+
+ raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
+ except ValueError as err:
+ # HuggingFace Hub returns a ValueError for a missing file when local_files_only=True, so we need to catch it
+ # here. This could be caught above in the `EntryNotFoundError` branch if hf_hub sent a different error message.
+ if LOCAL_FILES_ONLY_HF_ERROR in err.args[0] and local_files_only and not _raise_exceptions_for_missing_entries:
+ return None
+
+ # Otherwise we try to see if we have a cached version (not up to date):
+ resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
+ if resolved_file is not None:
+ return resolved_file
+ if not _raise_exceptions_for_connection_errors:
+ return None
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
+ f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
+ f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
+ " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
+ )
- return cache_path
+ return resolved_file
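A hedged sketch of how the private `_raise_exceptions_*` switches are meant to be used by callers that treat a file as optional; the repo id and filename are illustrative.

```python
from transformers.utils.hub import cached_file

resolved = cached_file(
    "bert-base-uncased",
    "tokenizer_config.json",
    _raise_exceptions_for_missing_entries=False,
    _raise_exceptions_for_connection_errors=False,
)
if resolved is None:
    # Missing entry or connection problem: the caller decides what to do.
    print("tokenizer_config.json is unavailable, falling back to defaults.")
```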
def get_file_from_repo(
@@ -611,6 +420,7 @@ def get_file_from_repo(
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
+ subfolder: str = "",
):
"""
Tries to locate a file in a local folder and repo, downloads and caches it if necessary.
@@ -636,13 +446,16 @@ def get_file_from_repo(
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
- when running `transformers-cli login` (stored in `~/.huggingface`).
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
@@ -662,54 +475,26 @@ def get_file_from_repo(
# This model does not have a tokenizer config so the result will be None.
tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
```"""
- if is_offline_mode() and not local_files_only:
- logger.info("Offline mode: forcing local_files_only=True")
- local_files_only = True
-
- path_or_repo = str(path_or_repo)
- if os.path.isdir(path_or_repo):
- resolved_file = os.path.join(path_or_repo, filename)
- return resolved_file if os.path.isfile(resolved_file) else None
- else:
- resolved_file = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=None)
-
- try:
- # Load from URL or cache if already cached
- resolved_file = cached_path(
- resolved_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- )
-
- except RepositoryNotFoundError:
- raise EnvironmentError(
- f"{path_or_repo} is not a local folder and is not a valid model identifier "
- "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
- "pass a token having permission to this repo with `use_auth_token` or log in with "
- "`huggingface-cli login` and pass `use_auth_token=True`."
- )
- except RevisionNotFoundError:
- raise EnvironmentError(
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
- "for this model name. Check the model page at "
- f"'https://huggingface.co/{path_or_repo}' for available revisions."
- )
- except EnvironmentError:
- # The repo and revision exist, but the file does not or there was a connection error fetching it.
- return None
-
- return resolved_file
+ return cached_file(
+ path_or_repo_id=path_or_repo,
+ filename=filename,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ local_files_only=local_files_only,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ _raise_exceptions_for_connection_errors=False,
+ )
def has_file(
path_or_repo: Union[str, os.PathLike],
filename: str,
revision: Optional[str] = None,
- mirror: Optional[str] = None,
proxies: Optional[Dict[str, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
):
@@ -726,7 +511,7 @@ def has_file(
if os.path.isdir(path_or_repo):
return os.path.isfile(os.path.join(path_or_repo, filename))
- url = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=mirror)
+ url = hf_hub_url(path_or_repo, filename=filename, revision=revision)
headers = {"user-agent": http_user_agent()}
if isinstance(use_auth_token, str):
@@ -739,7 +524,7 @@ def has_file(
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
try:
- _raise_for_status(r)
+ huggingface_hub.utils._errors._raise_for_status(r)
return True
except RepositoryNotFoundError as e:
logger.error(e)
@@ -755,135 +540,127 @@ def has_file(
return False
-def get_list_of_files(
- path_or_repo: Union[str, os.PathLike],
- revision: Optional[str] = None,
- use_auth_token: Optional[Union[bool, str]] = None,
- local_files_only: bool = False,
-) -> List[str]:
+class PushToHubMixin:
"""
- Gets the list of files inside `path_or_repo`.
-
- Args:
- path_or_repo (`str` or `os.PathLike`):
- Can be either the id of a repo on huggingface.co or a path to a *directory*.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
- identifier allowed by git.
- use_auth_token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
- when running `transformers-cli login` (stored in `~/.huggingface`).
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether or not to only rely on local files and not to attempt to download any files.
-
-
-
- This API is not optimized, so calling it a lot may result in connection errors.
-
-
-
- Returns:
- `List[str]`: The list of files available in `path_or_repo`.
+ A Mixin containing the functionality to push a model or tokenizer to the hub.
"""
- path_or_repo = str(path_or_repo)
- # If path_or_repo is a folder, we just return what is inside (subdirectories included).
- if os.path.isdir(path_or_repo):
- list_of_files = []
- for path, dir_names, file_names in os.walk(path_or_repo):
- list_of_files.extend([os.path.join(path, f) for f in file_names])
- return list_of_files
-
- # Can't grab the files if we are on offline mode.
- if is_offline_mode() or local_files_only:
- return []
-
- # Otherwise we grab the token and use the list_repo_files method.
- if isinstance(use_auth_token, str):
- token = use_auth_token
- elif use_auth_token is True:
- token = HfFolder.get_token()
- else:
- token = None
-
- try:
- return list_repo_files(path_or_repo, revision=revision, token=token)
- except HTTPError as e:
- raise ValueError(
- f"{path_or_repo} is not a local path or a model identifier on the model Hub. Did you make a typo?"
- ) from e
-
-def is_local_clone(repo_path, repo_url):
- """
- Checks if the folder in `repo_path` is a local clone of `repo_url`.
- """
- # First double-check that `repo_path` is a git repo
- if not os.path.exists(os.path.join(repo_path, ".git")):
- return False
- test_git = subprocess.run("git branch".split(), cwd=repo_path)
- if test_git.returncode != 0:
- return False
+ def _create_repo(
+ self,
+ repo_id: str,
+ private: Optional[bool] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ repo_url: Optional[str] = None,
+ organization: Optional[str] = None,
+ ):
+ """
+ Create the repo if needed, clean up repo_id with the deprecated kwargs `repo_url` and `organization`, and
+ retrieve the token.
+ """
+ if repo_url is not None:
+ warnings.warn(
+ "The `repo_url` argument is deprecated and will be removed in v5 of Transformers. Use `repo_id` "
+ "instead."
+ )
+ repo_id = repo_url.replace(f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/", "")
+ if organization is not None:
+ warnings.warn(
+ "The `organization` argument is deprecated and will be removed in v5 of Transformers. Set your "
+ "organization directly in the `repo_id` passed instead (`repo_id={organization}/{model_id}`)."
+ )
+ if not repo_id.startswith(organization):
+ if "/" in repo_id:
+ repo_id = repo_id.split("/")[-1]
+ repo_id = f"{organization}/{repo_id}"
- # Then look at its remotes
- remotes = subprocess.run(
- "git remote -v".split(),
- stderr=subprocess.PIPE,
- stdout=subprocess.PIPE,
- check=True,
- encoding="utf-8",
- cwd=repo_path,
- ).stdout
+ token = HfFolder.get_token() if use_auth_token is True else use_auth_token
+ url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
- return repo_url in remotes.split()
+ # If the namespace is not there, add it or `upload_file` will complain
+ if "/" not in repo_id and url != f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{repo_id}":
+ repo_id = get_full_repo_name(repo_id, token=token)
+ return repo_id, token
+ def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]):
+ """
+ Returns the list of files with their last modification timestamp.
+ """
+ return {f: os.path.getmtime(os.path.join(working_dir, f)) for f in os.listdir(working_dir)}
-class PushToHubMixin:
- """
- A Mixin containing the functionality to push a model or tokenizer to the hub.
- """
+ def _upload_modified_files(
+ self,
+ working_dir: Union[str, os.PathLike],
+ repo_id: str,
+ files_timestamps: Dict[str, float],
+ commit_message: Optional[str] = None,
+ token: Optional[str] = None,
+ create_pr: bool = False,
+ ):
+ """
+ Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`.
+ """
+ if commit_message is None:
+ if "Model" in self.__class__.__name__:
+ commit_message = "Upload model"
+ elif "Config" in self.__class__.__name__:
+ commit_message = "Upload config"
+ elif "Tokenizer" in self.__class__.__name__:
+ commit_message = "Upload tokenizer"
+ elif "FeatureExtractor" in self.__class__.__name__:
+ commit_message = "Upload feature extractor"
+ elif "Processor" in self.__class__.__name__:
+ commit_message = "Upload processor"
+ else:
+ commit_message = f"Upload {self.__class__.__name__}"
+ modified_files = [
+ f
+ for f in os.listdir(working_dir)
+ if f not in files_timestamps or os.path.getmtime(os.path.join(working_dir, f)) > files_timestamps[f]
+ ]
+ operations = []
+ for file in modified_files:
+ operations.append(CommitOperationAdd(path_or_fileobj=os.path.join(working_dir, file), path_in_repo=file))
+ logger.info(f"Uploading the following files to {repo_id}: {','.join(modified_files)}")
+ return create_commit(
+ repo_id=repo_id, operations=operations, commit_message=commit_message, token=token, create_pr=create_pr
+ )
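The mixin no longer keeps a git clone: it snapshots file modification times before `save_pretrained`, then uploads only files whose mtime changed. Below is a standalone sketch of that detection logic under an assumed scratch directory; it mirrors the two helpers above but is not the library code itself.

```python
import os


def snapshot_mtimes(working_dir):
    # Same idea as PushToHubMixin._get_files_timestamps.
    return {f: os.path.getmtime(os.path.join(working_dir, f)) for f in os.listdir(working_dir)}


def modified_since(working_dir, files_timestamps):
    # Same filter as PushToHubMixin._upload_modified_files.
    return [
        f
        for f in os.listdir(working_dir)
        if f not in files_timestamps or os.path.getmtime(os.path.join(working_dir, f)) > files_timestamps[f]
    ]


os.makedirs("./scratch-repo", exist_ok=True)
before = snapshot_mtimes("./scratch-repo")
with open("./scratch-repo/config.json", "w") as f:
    f.write("{}")
print(modified_since("./scratch-repo", before))  # ['config.json']
```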
def push_to_hub(
self,
- repo_path_or_name: Optional[str] = None,
- repo_url: Optional[str] = None,
- use_temp_dir: bool = False,
+ repo_id: str,
+ use_temp_dir: Optional[bool] = None,
commit_message: Optional[str] = None,
- organization: Optional[str] = None,
private: Optional[bool] = None,
use_auth_token: Optional[Union[bool, str]] = None,
- **model_card_kwargs
+ max_shard_size: Optional[Union[int, str]] = "10GB",
+ create_pr: bool = False,
+ **deprecated_kwargs
) -> str:
"""
Upload the {object_files} to the 🤗 Model Hub while synchronizing a local clone of the repo in
`repo_path_or_name`.
Parameters:
- repo_path_or_name (`str`, *optional*):
- Can either be a repository name for your {object} in the Hub or a path to a local folder (in which case
- the repository will have the name of that local folder). If not specified, will default to the name
- given by `repo_url` and a local directory with that name will be created.
- repo_url (`str`, *optional*):
- Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
- repository will be created in your namespace (unless you specify an `organization`) with `repo_name`.
- use_temp_dir (`bool`, *optional*, defaults to `False`):
- Whether or not to clone the distant repo in a temporary directory or in `repo_path_or_name` inside the
- current working directory. This will slow things down if you are making changes in an existing repo
- since you will need to clone the repo before every push.
+ repo_id (`str`):
+ The name of the repository you want to push your {object} to. It should contain your organization name
+ when pushing to a given organization.
+ use_temp_dir (`bool`, *optional*):
+ Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
+ Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
commit_message (`str`, *optional*):
- Message to commit while pushing. Will default to `"add {object}"`.
- organization (`str`, *optional*):
- Organization in which you want to push your {object} (you must be a member of this organization).
+ Message to commit while pushing. Will default to `"Upload {object}"`.
private (`bool`, *optional*):
Whether or not the repository created should be private (requires a paying subscription).
use_auth_token (`bool` or `str`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
- when running `transformers-cli login` (stored in `~/.huggingface`). Will default to `True` if
- `repo_url` is not specified.
-
-
- Returns:
- `str`: The url of the commit of your {object} in the given repository.
+ when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`
+ is not specified.
+ max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+ Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
+ will then be each of size lower than this size. If expressed as a string, needs to be digits followed
+ by a unit (like `"5MB"`).
+ create_pr (`bool`, *optional*, defaults to `False`):
+ Whether or not to create a PR with the uploaded files or directly commit.
Examples:
@@ -892,133 +669,46 @@ def push_to_hub(
{object} = {object_class}.from_pretrained("bert-base-cased")
- # Push the {object} to your namespace with the name "my-finetuned-bert" and have a local clone in the
- # *my-finetuned-bert* folder.
+ # Push the {object} to your namespace with the name "my-finetuned-bert".
{object}.push_to_hub("my-finetuned-bert")
- # Push the {object} to your namespace with the name "my-finetuned-bert" with no local clone.
- {object}.push_to_hub("my-finetuned-bert", use_temp_dir=True)
-
- # Push the {object} to an organization with the name "my-finetuned-bert" and have a local clone in the
- # *my-finetuned-bert* folder.
- {object}.push_to_hub("my-finetuned-bert", organization="huggingface")
-
- # Make a change to an existing repo that has been cloned locally in *my-finetuned-bert*.
- {object}.push_to_hub("my-finetuned-bert", repo_url="https://huggingface.co/sgugger/my-finetuned-bert")
+ # Push the {object} to an organization with the name "my-finetuned-bert".
+ {object}.push_to_hub("huggingface/my-finetuned-bert")
```
"""
- if use_temp_dir:
- # Make sure we use the right `repo_name` for the `repo_url` before replacing it.
- if repo_url is None:
- if use_auth_token is None:
- use_auth_token = True
- repo_name = Path(repo_path_or_name).name
- repo_url = self._get_repo_url_from_name(
- repo_name, organization=organization, private=private, use_auth_token=use_auth_token
- )
- repo_path_or_name = tempfile.mkdtemp()
-
- # Create or clone the repo. If the repo is already cloned, this just retrieves the path to the repo.
- repo = self._create_or_get_repo(
- repo_path_or_name=repo_path_or_name,
- repo_url=repo_url,
- organization=organization,
- private=private,
- use_auth_token=use_auth_token,
- )
- # Save the files in the cloned repo
- self.save_pretrained(repo_path_or_name)
- if hasattr(self, "history") and hasattr(self, "create_model_card"):
- # This is a Keras model and we might be able to fish out its History and make a model card out of it
- base_model_card_args = {
- "output_dir": repo_path_or_name,
- "model_name": Path(repo_path_or_name).name,
- }
- base_model_card_args.update(model_card_kwargs)
- self.create_model_card(**base_model_card_args)
- # Commit and push!
- url = self._push_to_hub(repo, commit_message=commit_message)
-
- # Clean up! Clean up! Everybody everywhere!
- if use_temp_dir:
- shutil.rmtree(repo_path_or_name)
-
- return url
-
- @staticmethod
- def _get_repo_url_from_name(
- repo_name: str,
- organization: Optional[str] = None,
- private: bool = None,
- use_auth_token: Optional[Union[bool, str]] = None,
- ) -> str:
- if isinstance(use_auth_token, str):
- token = use_auth_token
- elif use_auth_token:
- token = HfFolder.get_token()
- if token is None:
- raise ValueError(
- "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
- "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
- "token as the `use_auth_token` argument."
- )
+ if "repo_path_or_name" in deprecated_kwargs:
+ warnings.warn(
+ "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
+ "`repo_id` instead."
+ )
+ repo_id = deprecated_kwargs.pop("repo_path_or_name")
+ # Deprecation warning will be sent after for repo_url and organization
+ repo_url = deprecated_kwargs.pop("repo_url", None)
+ organization = deprecated_kwargs.pop("organization", None)
+
+ if os.path.isdir(repo_id):
+ working_dir = repo_id
+ repo_id = repo_id.split(os.path.sep)[-1]
else:
- token = None
-
- # Special provision for the test endpoint (CI)
- return create_repo(
- token,
- repo_name,
- organization=organization,
- private=private,
- repo_type=None,
- exist_ok=True,
+ working_dir = repo_id.split("/")[-1]
+
+ repo_id, token = self._create_repo(
+ repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization
)
- @classmethod
- def _create_or_get_repo(
- cls,
- repo_path_or_name: Optional[str] = None,
- repo_url: Optional[str] = None,
- organization: Optional[str] = None,
- private: bool = None,
- use_auth_token: Optional[Union[bool, str]] = None,
- ) -> Repository:
- if repo_path_or_name is None and repo_url is None:
- raise ValueError("You need to specify a `repo_path_or_name` or a `repo_url`.")
+ if use_temp_dir is None:
+ use_temp_dir = not os.path.isdir(working_dir)
- if use_auth_token is None and repo_url is None:
- use_auth_token = True
+ with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
+ files_timestamps = self._get_files_timestamps(work_dir)
- if repo_path_or_name is None:
- repo_path_or_name = repo_url.split("/")[-1]
+ # Save all files.
+ self.save_pretrained(work_dir, max_shard_size=max_shard_size)
- if repo_url is None and not os.path.exists(repo_path_or_name):
- repo_name = Path(repo_path_or_name).name
- repo_url = cls._get_repo_url_from_name(
- repo_name, organization=organization, private=private, use_auth_token=use_auth_token
+ return self._upload_modified_files(
+ work_dir, repo_id, files_timestamps, commit_message=commit_message, token=token, create_pr=create_pr
)
- # Create a working directory if it does not exist.
- if not os.path.exists(repo_path_or_name):
- os.makedirs(repo_path_or_name)
-
- repo = Repository(repo_path_or_name, clone_from=repo_url, use_auth_token=use_auth_token)
- repo.git_pull()
- return repo
-
- @classmethod
- def _push_to_hub(cls, repo: Repository, commit_message: Optional[str] = None) -> str:
- if commit_message is None:
- if "Tokenizer" in cls.__name__:
- commit_message = "add tokenizer"
- elif "Config" in cls.__name__:
- commit_message = "add config"
- else:
- commit_message = "add model"
-
- return repo.push_to_hub(commit_message=commit_message)
-
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
@@ -1028,3 +718,338 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
+
+
+def send_example_telemetry(example_name, *example_args, framework="pytorch"):
+ """
+ Sends telemetry that helps us track how the examples are used.
+
+ Args:
+ example_name (`str`): The name of the example.
+ *example_args (dataclasses or `argparse.ArgumentParser`): The arguments to the script. This function will only
+ try to extract the model and dataset name from those. Nothing else is tracked.
+ framework (`str`, *optional*, defaults to `"pytorch"`): The framework for the example.
+ """
+ if is_offline_mode():
+ return
+
+ data = {"example": example_name, "framework": framework}
+ for args in example_args:
+ args_as_dict = {k: v for k, v in args.__dict__.items() if not k.startswith("_") and v is not None}
+ if "model_name_or_path" in args_as_dict:
+ model_name = args_as_dict["model_name_or_path"]
+ # Filter out local paths
+ if not os.path.isdir(model_name):
+ data["model_name"] = args_as_dict["model_name_or_path"]
+ if "dataset_name" in args_as_dict:
+ data["dataset_name"] = args_as_dict["dataset_name"]
+ elif "task_name" in args_as_dict:
+ # Extract script name from the example_name
+ script_name = example_name.replace("tf_", "").replace("flax_", "").replace("run_", "")
+ script_name = script_name.replace("_no_trainer", "")
+ data["dataset_name"] = f"{script_name}-{args_as_dict['task_name']}"
+
+ headers = {"user-agent": http_user_agent(data)}
+ try:
+ r = requests.head(HUGGINGFACE_CO_EXAMPLES_TELEMETRY, headers=headers)
+ r.raise_for_status()
+ except Exception:
+ # We don't want to error in case of connection errors of any kind.
+ pass
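A hedged usage sketch of the telemetry helper above; the argument values are made up, and the call is a no-op when offline mode is enabled:

```python
# Sketch only: `Namespace` stands in for the parsed arguments of an example script.
from argparse import Namespace

from transformers.utils import send_example_telemetry  # exposed from transformers.utils on this branch

args = Namespace(model_name_or_path="bert-base-uncased", dataset_name="squad", seed=42)
send_example_telemetry("run_qa", args, framework="pytorch")  # only the model and dataset names are extracted
```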
+
+
+def convert_file_size_to_int(size: Union[int, str]):
+ """
+ Converts a size expressed as a string with digits and a unit (like `"5MB"`) to an integer (in bytes).
+
+ Args:
+ size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
+
+ Example:
+ ```py
+ >>> convert_file_size_to_int("1MiB")
+ 1048576
+ ```
+ """
+ if isinstance(size, int):
+ return size
+ if size.upper().endswith("GIB"):
+ return int(size[:-3]) * (2**30)
+ if size.upper().endswith("MIB"):
+ return int(size[:-3]) * (2**20)
+ if size.upper().endswith("KIB"):
+ return int(size[:-3]) * (2**10)
+ if size.upper().endswith("GB"):
+ int_size = int(size[:-2]) * (10**9)
+ return int_size // 8 if size.endswith("b") else int_size
+ if size.upper().endswith("MB"):
+ int_size = int(size[:-2]) * (10**6)
+ return int_size // 8 if size.endswith("b") else int_size
+ if size.upper().endswith("KB"):
+ int_size = int(size[:-2]) * (10**3)
+ return int_size // 8 if size.endswith("b") else int_size
+ raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
+
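A small sketch contrasting decimal and binary units as handled by the function above; the import path is an assumption for this branch:

```python
# Both results follow directly from the branches above: "GB" is decimal, "GiB" is binary.
from transformers.utils.hub import convert_file_size_to_int  # import path assumed for this branch

assert convert_file_size_to_int("5GB") == 5 * 10**9
assert convert_file_size_to_int("5GiB") == 5 * 2**30
assert convert_file_size_to_int(1024) == 1024  # ints are returned unchanged
```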
+
+def get_checkpoint_shard_files(
+ pretrained_model_name_or_path,
+ index_filename,
+ cache_dir=None,
+ force_download=False,
+ proxies=None,
+ resume_download=False,
+ local_files_only=False,
+ use_auth_token=None,
+ user_agent=None,
+ revision=None,
+ subfolder="",
+):
+ """
+ For a given model:
+
+ - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
+ Hub
+ - returns the list of paths to all the shards, as well as some metadata.
+
+ For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
+ index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
+ """
+ import json
+
+ if not os.path.isfile(index_filename):
+ raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
+
+ with open(index_filename, "r") as f:
+ index = json.loads(f.read())
+
+ shard_filenames = sorted(list(set(index["weight_map"].values())))
+ sharded_metadata = index["metadata"]
+ sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
+
+ # First, let's deal with local folder.
+ if os.path.isdir(pretrained_model_name_or_path):
+ shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
+ return shard_filenames, sharded_metadata
+
+ # At this stage pretrained_model_name_or_path is a model identifier on the Hub
+ cached_filenames = []
+ for shard_filename in shard_filenames:
+ try:
+ # Load from URL
+ cached_filename = cached_file(
+ pretrained_model_name_or_path,
+ shard_filename,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ )
+ # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
+ # we don't have to catch them here.
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
+ "required according to the checkpoint index."
+ )
+ except HTTPError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
+ " again after checking your internet connection."
+ )
+
+ cached_filenames.append(cached_filename)
+
+ return cached_filenames, sharded_metadata
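A hedged local-folder sketch of the helper above; the folder name and index filename are assumptions for illustration:

```python
# Assumes "my_sharded_model/" contains pytorch_model.bin.index.json plus the shard files it lists.
import os

from transformers.utils.hub import get_checkpoint_shard_files  # import path assumed for this branch

folder = "my_sharded_model"
index_file = os.path.join(folder, "pytorch_model.bin.index.json")

shard_paths, sharded_metadata = get_checkpoint_shard_files(folder, index_file)
print(shard_paths)                                    # paths to each shard, one entry per file
print(len(sharded_metadata["all_checkpoint_keys"]))   # number of weight names across all shards
```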
+
+
+# Everything below is for converting between the old cache format and the new cache format.
+
+
+def get_all_cached_files(cache_dir=None):
+ """
+ Returns a list of all cached files along with their metadata.
+ """
+ if cache_dir is None:
+ cache_dir = TRANSFORMERS_CACHE
+ else:
+ cache_dir = str(cache_dir)
+ if not os.path.isdir(cache_dir):
+ return []
+
+ cached_files = []
+ for file in os.listdir(cache_dir):
+ meta_path = os.path.join(cache_dir, f"{file}.json")
+ if not os.path.isfile(meta_path):
+ continue
+
+ with open(meta_path, encoding="utf-8") as meta_file:
+ metadata = json.load(meta_file)
+ url = metadata["url"]
+ etag = metadata["etag"].replace('"', "")
+ cached_files.append({"file": file, "url": url, "etag": etag})
+
+ return cached_files
+
+
+def get_hub_metadata(url, token=None):
+ """
+ Returns the etag and associated commit hash for a given url.
+ """
+ if token is None:
+ token = HfFolder.get_token()
+ headers = {"user-agent": http_user_agent()}
+ headers["authorization"] = f"Bearer {token}"
+
+ r = huggingface_hub.file_download._request_with_retry(
+ method="HEAD", url=url, headers=headers, allow_redirects=False
+ )
+ huggingface_hub.file_download._raise_for_status(r)
+ commit_hash = r.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT)
+ etag = r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")
+ if etag is not None:
+ etag = huggingface_hub.file_download._normalize_etag(etag)
+ return etag, commit_hash
+
+
+def extract_info_from_url(url):
+ """
+ Extract the repo name, revision and filename from a url.
+ """
+ search = re.search(r"^https://huggingface\.co/(.*)/resolve/([^/]*)/(.*)$", url)
+ if search is None:
+ return None
+ repo, revision, filename = search.groups()
+ cache_repo = "--".join(["models"] + repo.split("/"))
+ return {"repo": cache_repo, "revision": revision, "filename": filename}
+
+
+def clean_files_for(file):
+ """
+ Remove `file`, `file.json` and `file.lock` if they exist.
+ """
+ for f in [file, f"{file}.json", f"{file}.lock"]:
+ if os.path.isfile(f):
+ os.remove(f)
+
+
+def move_to_new_cache(file, repo, filename, revision, etag, commit_hash):
+ """
+ Move `file` to `repo` following the new huggingface_hub cache organization.
+ """
+ os.makedirs(repo, exist_ok=True)
+
+ # refs
+ os.makedirs(os.path.join(repo, "refs"), exist_ok=True)
+ if revision != commit_hash:
+ ref_path = os.path.join(repo, "refs", revision)
+ with open(ref_path, "w") as f:
+ f.write(commit_hash)
+
+ # blobs
+ os.makedirs(os.path.join(repo, "blobs"), exist_ok=True)
+ blob_path = os.path.join(repo, "blobs", etag)
+ shutil.move(file, blob_path)
+
+ # snapshots
+ os.makedirs(os.path.join(repo, "snapshots"), exist_ok=True)
+ os.makedirs(os.path.join(repo, "snapshots", commit_hash), exist_ok=True)
+ pointer_path = os.path.join(repo, "snapshots", commit_hash, filename)
+ huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
+ clean_files_for(file)
+
+
+def move_cache(cache_dir=None, new_cache_dir=None, token=None):
+ if new_cache_dir is None:
+ new_cache_dir = TRANSFORMERS_CACHE
+ if cache_dir is None:
+ # Migrate from old cache in .cache/huggingface/hub
+ old_cache = Path(TRANSFORMERS_CACHE).parent / "transformers"
+ if os.path.isdir(str(old_cache)):
+ cache_dir = str(old_cache)
+ else:
+ cache_dir = new_cache_dir
+ if token is None:
+ token = HfFolder.get_token()
+ cached_files = get_all_cached_files(cache_dir=cache_dir)
+ print(f"Moving {len(cached_files)} files to the new cache system")
+
+ hub_metadata = {}
+ for file_info in tqdm(cached_files):
+ url = file_info.pop("url")
+ if url not in hub_metadata:
+ try:
+ hub_metadata[url] = get_hub_metadata(url, token=token)
+ except requests.HTTPError:
+ continue
+
+ etag, commit_hash = hub_metadata[url]
+ if etag is None or commit_hash is None:
+ continue
+
+ if file_info["etag"] != etag:
+ # Cached file is not up to date, so we just discard it; a new version will be downloaded anyway.
+ clean_files_for(os.path.join(cache_dir, file_info["file"]))
+ continue
+
+ url_info = extract_info_from_url(url)
+ if url_info is None:
+ # Not a file from huggingface.co
+ continue
+
+ repo = os.path.join(new_cache_dir, url_info["repo"])
+ move_to_new_cache(
+ file=os.path.join(cache_dir, file_info["file"]),
+ repo=repo,
+ filename=url_info["filename"],
+ revision=url_info["revision"],
+ etag=etag,
+ commit_hash=commit_hash,
+ )
+
+
+cache_version_file = os.path.join(TRANSFORMERS_CACHE, "version.txt")
+if not os.path.isfile(cache_version_file):
+ cache_version = 0
+else:
+ with open(cache_version_file) as f:
+ cache_version = int(f.read())
+
+
+if cache_version < 1:
+ if is_offline_mode():
+ logger.warn(
+ "You are offline and the cache for model files in Transformers v4.22.0 has been updated while your local "
+ "cache seems to be the one of a previous version. It is very likely that all your calls to any "
+ "`from_pretrained()` method will fail. Remove the offline mode and enable internet connection to have "
+ "your cache be updated automatically, then you can go back to offline mode."
+ )
+ else:
+ logger.warn(
+ "The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a "
+ "one-time only operation. You can interrupt this and resume the migration later on by calling "
+ "`transformers.utils.move_cache()`."
+ )
+ try:
+ move_cache()
+ except Exception as e:
+ trace = "\n".join(traceback.format_tb(e.__traceback__))
+ logger.error(
+ f"There was a problem when trying to move your cache:\n\n{trace}\n\nPlease file an issue at "
+ "https://github.com/huggingface/transformers/issues/new/choose and copy paste this whole message and we "
+ "will do our best to help."
+ )
+
+ try:
+ os.makedirs(TRANSFORMERS_CACHE, exist_ok=True)
+ with open(cache_version_file, "w") as f:
+ f.write("1")
+ except Exception:
+ logger.warn(
+ f"There was a problem when trying to write in your cache folder ({TRANSFORMERS_CACHE}). You should set "
+ "the environment variable TRANSFORMERS_CACHE to a writable directory."
+ )
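As the migration warning above notes, the one-time migration can be interrupted and resumed by hand; a minimal sketch:

```python
# Re-runs the migration from the old cache layout to the huggingface_hub layout.
# Safe to call again: already-moved files no longer appear in the old cache, and
# outdated entries are cleaned up rather than copied.
from transformers.utils import move_cache  # referenced this way in the warning message above

move_cache()
```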
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 505ba94e0b19..37172d14fcc2 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -19,6 +19,7 @@
import json
import os
import sys
+import warnings
from collections import OrderedDict
from functools import wraps
from itertools import chain
@@ -70,6 +71,7 @@
"intel-tensorflow-avx512",
"tensorflow-rocm",
"tensorflow-macos",
+ "tensorflow-aarch64",
)
_tf_version = None
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
@@ -245,6 +247,16 @@
except importlib_metadata.PackageNotFoundError:
_librosa_available = False
+ccl_version = "N/A"
+_is_ccl_available = (
+ importlib.util.find_spec("torch_ccl") is not None
+ or importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
+)
+try:
+ ccl_version = importlib_metadata.version("oneccl_bind_pt")
+ logger.debug(f"Successfully imported oneccl_bind_pt version {ccl_version}")
+except importlib_metadata.PackageNotFoundError:
+ _is_ccl_available = False
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
@@ -272,7 +284,7 @@ def is_torch_cuda_available():
return False
-def is_torch_bf16_available():
+def is_torch_bf16_gpu_available():
if not is_torch_available():
return False
@@ -282,27 +294,57 @@ def is_torch_bf16_available():
# some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
# with additional check for torch version
# to succeed:
- # 1. the hardware needs to support bf16 (arch >= Ampere)
- # 2. torch >= 1.10 (1.9 should be enough for AMP API has changed in 1.10, so using 1.10 as minimal)
- # 3. CUDA >= 11
+ # 1. torch >= 1.10 (1.9 might be enough, but the AMP API changed in 1.10, so we use 1.10 as the minimum)
+ # 2. the hardware needs to support bf16 (GPU arch >= Ampere, or CPU)
+ # 3. if using gpu, CUDA >= 11
# 4. torch.autocast exists
# XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
# really only correct for the 0th gpu (or currently set default device if different from 0)
-
- if not torch.cuda.is_available() or torch.version.cuda is None:
+ if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.10"):
return False
- if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
+
+ if torch.cuda.is_available() and torch.version.cuda is not None:
+ if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
+ return False
+ if int(torch.version.cuda.split(".")[0]) < 11:
+ return False
+ if not hasattr(torch.cuda.amp, "autocast"):
+ return False
+ else:
return False
- if int(torch.version.cuda.split(".")[0]) < 11:
+
+ return True
+
+
+def is_torch_bf16_cpu_available():
+ if not is_torch_available():
return False
- if version.parse(torch.__version__) < version.parse("1.10"):
+
+ import torch
+
+ if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.10"):
return False
- if not hasattr(torch, "autocast"):
+
+ try:
+ # multiple levels of AttributeError depending on the pytorch version so do them all in one check
+ _ = torch.cpu.amp.autocast
+ except AttributeError:
return False
return True
+def is_torch_bf16_available():
+ # the original bf16 check was gpu-only, but a cpu/bf16 combination has since emerged, so this util
+ # has become ambiguous and is therefore deprecated
+ warnings.warn(
+ "The util is_torch_bf16_available is deprecated, please use is_torch_bf16_gpu_available "
+ "or is_torch_bf16_cpu_available instead according to whether it's used with cpu or gpu",
+ FutureWarning,
+ )
+ return is_torch_bf16_gpu_available()
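For callers of the deprecated helper, a migration sketch using the explicit variants defined above:

```python
# Pick the check that matches the device you intend to run on; combining both mirrors
# the broadest interpretation of the old is_torch_bf16_available().
from transformers.utils.import_utils import (
    is_torch_bf16_cpu_available,
    is_torch_bf16_gpu_available,
)

use_bf16 = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
print(f"bf16 usable on this setup: {use_bf16}")
```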
+
+
def is_torch_tf32_available():
if not is_torch_available():
return False
@@ -315,7 +357,7 @@ def is_torch_tf32_available():
return False
if int(torch.version.cuda.split(".")[0]) < 11:
return False
- if version.parse(torch.__version__) < version.parse("1.7"):
+ if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
return False
return True
@@ -325,7 +367,7 @@ def is_torch_tf32_available():
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
if _torch_available:
torch_version = version.parse(importlib_metadata.version("torch"))
- _torch_fx_available = (torch_version.major, torch_version.minor) == (
+ _torch_fx_available = (torch_version.major, torch_version.minor) >= (
TORCH_FX_REQUIRED_VERSION.major,
TORCH_FX_REQUIRED_VERSION.minor,
)
@@ -365,15 +407,32 @@ def is_ftfy_available():
return _ftfy_available
-def is_torch_tpu_available():
+def is_torch_tpu_available(check_device=True):
+ "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
if not _torch_available:
return False
- # This test is probably enough, but just in case, we unpack a bit.
- if importlib.util.find_spec("torch_xla") is None:
- return False
- if importlib.util.find_spec("torch_xla.core") is None:
+ if importlib.util.find_spec("torch_xla") is not None:
+ if check_device:
+ # We need to check if `xla_device` can be found, will raise a RuntimeError if not
+ try:
+ import torch_xla.core.xla_model as xm
+
+ _ = xm.xla_device()
+ return True
+ except RuntimeError:
+ return False
+ return True
+ return False
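A short sketch of the new `check_device` flag on `is_torch_tpu_available`:

```python
from transformers.utils.import_utils import is_torch_tpu_available

# Cheap check: only asks whether torch_xla is importable.
print(is_torch_tpu_available(check_device=False))
# Default behaviour: additionally tries to acquire an XLA device and returns False if that fails.
print(is_torch_tpu_available())
```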
+
+
+def is_torchdynamo_available():
+ return importlib.util.find_spec("torchdynamo") is not None
+
+
+def is_torch_tensorrt_fx_available():
+ if importlib.util.find_spec("torch_tensorrt") is None:
return False
- return importlib.util.find_spec("torch_xla.core.xla_model") is not None
+ return importlib.util.find_spec("torch_tensorrt.fx") is not None
def is_datasets_available():
@@ -396,10 +455,36 @@ def is_py3nvml_available():
return importlib.util.find_spec("py3nvml") is not None
+def is_sacremoses_available():
+ return importlib.util.find_spec("sacremoses") is not None
+
+
def is_apex_available():
return importlib.util.find_spec("apex") is not None
+def is_ipex_available():
+ def get_major_and_minor_from_version(full_version):
+ return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
+
+ if not is_torch_available() or importlib.util.find_spec("intel_extension_for_pytorch") is None:
+ return False
+ _ipex_version = "N/A"
+ try:
+ _ipex_version = importlib_metadata.version("intel_extension_for_pytorch")
+ except importlib_metadata.PackageNotFoundError:
+ return False
+ torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
+ ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
+ if torch_major_and_minor != ipex_major_and_minor:
+ logger.warning(
+ f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
+ f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
+ )
+ return False
+ return True
+
+
def is_bitsandbytes_available():
return importlib.util.find_spec("bitsandbytes") is not None
@@ -428,6 +513,10 @@ def is_protobuf_available():
return importlib.util.find_spec("google.protobuf") is not None
+def is_accelerate_available():
+ return importlib.util.find_spec("accelerate") is not None
+
+
def is_tokenizers_available():
return importlib.util.find_spec("tokenizers") is not None
@@ -444,6 +533,10 @@ def is_spacy_available():
return importlib.util.find_spec("spacy") is not None
+def is_tensorflow_text_available():
+ return importlib.util.find_spec("tensorflow_text") is not None
+
+
def is_in_notebook():
try:
# Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
@@ -452,6 +545,10 @@ def is_in_notebook():
raise ImportError("console")
if "VSCODE_PID" in os.environ:
raise ImportError("vscode")
+ if "DATABRICKS_RUNTIME_VERSION" in os.environ and os.environ["DATABRICKS_RUNTIME_VERSION"] < "11.0":
+ # Databricks Runtime 11.0 and above uses the IPython kernel by default, so it should be compatible with Jupyter notebooks
+ # https://docs.microsoft.com/en-us/azure/databricks/notebooks/ipython-kernel
+ raise ImportError("databricks")
return importlib.util.find_spec("IPython") is not None
except (AttributeError, ImportError, KeyError):
@@ -550,6 +647,10 @@ def wrapper(*args, **kwargs):
return wrapper
+def is_ccl_available():
+ return _is_ccl_available
+
+
# docstyle-ignore
DATASETS_IMPORT_ERROR = """
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
@@ -611,6 +712,30 @@ def wrapper(*args, **kwargs):
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""
+# docstyle-ignore
+PYTORCH_IMPORT_ERROR_WITH_TF = """
+{0} requires the PyTorch library but it was not found in your environment.
+However, we were able to find a TensorFlow installation. TensorFlow classes begin
+with "TF", but are otherwise identically named to our PyTorch classes. This
+means that the TF equivalent of the class you tried to import would be "TF{0}".
+If you want to use TensorFlow, please use TF classes instead!
+
+If you really do want to use PyTorch please go to
+https://pytorch.org/get-started/locally/ and follow the instructions that
+match your environment.
+"""
+
+# docstyle-ignore
+TF_IMPORT_ERROR_WITH_PYTORCH = """
+{0} requires the TensorFlow library but it was not found in your environment.
+However, we were able to find a PyTorch installation. PyTorch classes do not begin
+with "TF", but are otherwise identically named to our TF classes.
+If you want to use PyTorch, please use those classes instead!
+
+If you really do want to use TensorFlow, please follow the instructions on the
+installation page https://www.tensorflow.org/install that match your environment.
+"""
+
# docstyle-ignore
SKLEARN_IMPORT_ERROR = """
@@ -672,6 +797,12 @@ def wrapper(*args, **kwargs):
explained here: https://github.com/tensorflow/probability.
"""
+# docstyle-ignore
+TENSORFLOW_TEXT_IMPORT_ERROR = """
+{0} requires the tensorflow_text library but it was not found in your environment. You can install it with pip as
+explained here: https://www.tensorflow.org/text/guide/tf_text_intro.
+"""
+
# docstyle-ignore
PANDAS_IMPORT_ERROR = """
@@ -687,6 +818,13 @@ def wrapper(*args, **kwargs):
"""
+# docstyle-ignore
+SACREMOSES_IMPORT_ERROR = """
+{0} requires the sacremoses library but it was not found in your environment. You can install it with pip:
+`pip install sacremoses`
+"""
+
+
# docstyle-ignore
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip:
@@ -725,6 +863,17 @@ def wrapper(*args, **kwargs):
`pip install pyctcdecode`
"""
+# docstyle-ignore
+ACCELERATE_IMPORT_ERROR = """
+{0} requires the accelerate library but it was not found in your environment. You can install it with pip:
+`pip install accelerate`
+"""
+
+# docstyle-ignore
+CCL_IMPORT_ERROR = """
+{0} requires the torch ccl library but it was not found in your environment. You can install it with pip:
+`pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable`
+"""
BACKENDS_MAPPING = OrderedDict(
[
@@ -738,6 +887,7 @@ def wrapper(*args, **kwargs):
("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)),
("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
+ ("sacremoses", (is_sacremoses_available, SACREMOSES_IMPORT_ERROR)),
("scatter", (is_scatter_available, SCATTER_IMPORT_ERROR)),
("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)),
("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
@@ -745,11 +895,14 @@ def wrapper(*args, **kwargs):
("speech", (is_speech_available, SPEECH_IMPORT_ERROR)),
("tensorflow_probability", (is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR)),
("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
+ ("tensorflow_text", (is_tensorflow_text_available, TENSORFLOW_TEXT_IMPORT_ERROR)),
("timm", (is_timm_available, TIMM_IMPORT_ERROR)),
("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
("vision", (is_vision_available, VISION_IMPORT_ERROR)),
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
+ ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
+ ("oneccl_bind_pt", (is_ccl_available, CCL_IMPORT_ERROR)),
]
)
@@ -759,6 +912,15 @@ def requires_backends(obj, backends):
backends = [backends]
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+
+ # Raise an error for users who might not realize that classes without "TF" are torch-only
+ if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available():
+ raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
+
+ # Raise the inverse error for PyTorch users trying to load TF classes
+ if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available():
+ raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
+
checks = (BACKENDS_MAPPING[backend] for backend in backends)
failed = [msg.format(name) for available, msg in checks if not available()]
if failed:
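For context, a hedged sketch of the usual `requires_backends` call site, which is the path the two new cross-framework error messages above intercept; the class name is hypothetical:

```python
from transformers.utils.import_utils import requires_backends  # import path assumed for this branch


class MyTorchOnlyModel:
    def __init__(self):
        # With TensorFlow installed but not PyTorch, this now raises the
        # PYTORCH_IMPORT_ERROR_WITH_TF message instead of the generic PyTorch error.
        requires_backends(self, ["torch"])
```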
@@ -861,8 +1023,13 @@ def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
except Exception as e:
raise RuntimeError(
- f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}"
+ f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
+ f" traceback):\n{e}"
) from e
def __reduce__(self):
return (self.__class__, (self._name, self.__file__, self._import_structure))
+
+
+class OptionalDependencyNotAvailable(BaseException):
+ """Internally used error class for signalling an optional dependency was not found."""
diff --git a/src/transformers/utils/model_parallel_utils.py b/src/transformers/utils/model_parallel_utils.py
index abddd6c60fac..bcbe80801359 100644
--- a/src/transformers/utils/model_parallel_utils.py
+++ b/src/transformers/utils/model_parallel_utils.py
@@ -32,13 +32,15 @@ def assert_device_map(device_map, num_blocks):
if len(duplicate_blocks) != 0:
raise ValueError(
- "Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device. These "
- "attention blocks were specified more than once: " + str(duplicate_blocks)
+ "Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device."
+ " These attention blocks were specified more than once: "
+ + str(duplicate_blocks)
)
if len(missing_blocks) != 0:
raise ValueError(
"There are attention blocks for this model that are not specified in the device_map. Add these attention "
- "blocks to a device on the device_map: " + str(missing_blocks)
+ "blocks to a device on the device_map: "
+ + str(missing_blocks)
)
if len(extra_blocks) != 0:
raise ValueError(
diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py
index 0ffbdc8deecf..8d81d76c4fd1 100644
--- a/src/transformers/utils/notebook.py
+++ b/src/transformers/utils/notebook.py
@@ -174,7 +174,10 @@ def update_bar(self, value, comment=None):
elif self.predicted_remaining is None:
self.label = f"[{spaced_value}/{self.total} {format_time(self.elapsed_time)}"
else:
- self.label = f"[{spaced_value}/{self.total} {format_time(self.elapsed_time)} < {format_time(self.predicted_remaining)}"
+ self.label = (
+ f"[{spaced_value}/{self.total} {format_time(self.elapsed_time)} <"
+ f" {format_time(self.predicted_remaining)}"
+ )
self.label += f", {1/self.average_time_per_item:.2f} it/s"
self.label += "]" if self.comment is None or len(self.comment) == 0 else f", {self.comment}]"
self.display()
@@ -304,6 +307,11 @@ def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwarg
else:
self.prediction_bar.update(self.prediction_bar.value + 1)
+ def on_predict(self, args, state, control, **kwargs):
+ if self.prediction_bar is not None:
+ self.prediction_bar.close()
+ self.prediction_bar = None
+
def on_log(self, args, state, control, logs=None, **kwargs):
# Only for when there is no evaluation
if args.evaluation_strategy == IntervalStrategy.NO and "loss" in logs:
diff --git a/src/transformers/utils/sentencepiece_model_pb2.py b/src/transformers/utils/sentencepiece_model_pb2.py
index 5d52b365caab..41411cee8cd6 100644
--- a/src/transformers/utils/sentencepiece_model_pb2.py
+++ b/src/transformers/utils/sentencepiece_model_pb2.py
@@ -32,7 +32,53 @@
syntax="proto2",
serialized_options=b"H\003",
create_key=_descriptor._internal_create_key,
- serialized_pb=b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\xa1\n\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05\x12\x16\n\tbos_piece\x18. 
\x01(\t:\x03\x12\x17\n\teos_piece\x18/ \x01(\t:\x04\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03',
+ serialized_pb=(
+ b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\xa1\n\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01'
+ b" \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02"
+ b" \x01(\t\x12\x41\n\nmodel_type\x18\x03"
+ b" \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04"
+ b" \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12"
+ b' \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n'
+ b" \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b"
+ b" \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12"
+ b' \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r'
+ b" \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e"
+ b" \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f"
+ b" \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12"
+ b" \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10"
+ b" \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11"
+ b" \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14"
+ b" \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15"
+ b" \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17"
+ b" \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16"
+ b" \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18"
+ b" \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19"
+ b" \x01(\x08:\x05\x66\x61lse\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e"
+ b" \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$"
+ b" \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18"
+ b' \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18"'
+ b" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18)"
+ b" \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+"
+ b" \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05\x12\x16\n\tbos_piece\x18."
+ b" \x01(\t:\x03\x12\x17\n\teos_piece\x18/ \x01(\t:\x04\x12\x18\n\tpad_piece\x18\x30"
+ b" \x01(\t:\x05\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87"
+ b" \x12+\n\x1ctrain_extremely_large_corpus\x18\x31"
+ b' \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01'
+ b" \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03"
+ b" \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12"
+ b" \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06"
+ b' \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01'
+ b' \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01'
+ b" \x01(\t\x12\x10\n\x08\x65xpected\x18\x02"
+ b' \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01'
+ b" \x03(\x0b\x32'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02"
+ b" \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03"
+ b" \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04"
+ b" \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05"
+ b" \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01"
+ b" \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03"
+ b' \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03'
+ ),
)
diff --git a/src/transformers/utils/versions.py b/src/transformers/utils/versions.py
index 26a160f1fd6e..14db9b55e597 100644
--- a/src/transformers/utils/versions.py
+++ b/src/transformers/utils/versions.py
@@ -77,7 +77,8 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement)
if not match:
raise ValueError(
- f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
+ "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but"
+ f" got {requirement}"
)
pkg, want_full = match[0]
want_range = want_full.split(",") # there could be multiple requirements
@@ -86,7 +87,8 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
match = re.findall(r"^([\s!=<>]{1,2})(.+)", w)
if not match:
raise ValueError(
- f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
+ "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23,"
+ f" but got {requirement}"
)
op, want_ver = match[0]
wanted[op] = want_ver
diff --git a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
index 5a641f85f2ef..e7a622edd715 100755
--- a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
+++ b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
@@ -46,6 +46,7 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils import send_example_telemetry
logger = logging.getLogger(__name__)
@@ -117,7 +118,7 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "help": "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
},
)
@@ -207,6 +208,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_{{cookiecutter.example_shortcut}}", model_args, data_args)
+
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
@@ -519,6 +524,7 @@ def _mp_fn(index):
get_scheduler,
set_seed,
)
+from transformers.utils import send_example_telemetry
logger = logging.getLogger(__name__)
@@ -662,6 +668,10 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_{{cookiecutter.example_shortcut}", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator()
# Make one log on every process with the configuration for debugging.
diff --git a/templates/adding_a_new_model/ADD_NEW_MODEL_PROPOSAL_TEMPLATE.md b/templates/adding_a_new_model/ADD_NEW_MODEL_PROPOSAL_TEMPLATE.md
index 3b2de6f3c098..2066356470fb 100644
--- a/templates/adding_a_new_model/ADD_NEW_MODEL_PROPOSAL_TEMPLATE.md
+++ b/templates/adding_a_new_model/ADD_NEW_MODEL_PROPOSAL_TEMPLATE.md
@@ -990,7 +990,7 @@ tokenizer.
For [camelcase name of model], the tokenizer files can be found here:
- [To be filled out by mentor]
-and having implemented the 🤗Transformers' version of the tokenizer can be loaded as follows:
+and having implemented the 🤗 Transformers' version of the tokenizer can be loaded as follows:
[To be filled out by mentor]
diff --git a/templates/adding_a_new_model/README.md b/templates/adding_a_new_model/README.md
index 496c4f004be5..4bb6663937ce 100644
--- a/templates/adding_a_new_model/README.md
+++ b/templates/adding_a_new_model/README.md
@@ -222,7 +222,7 @@ You will also see a doc file and tests for your new models. First you should run
```
make style
-maxke fix-copies
+make fix-copies
```
and then you can start tweaking your model. You should:
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/__init__.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/__init__.py
index afcfeb87eb77..0d05ee406add 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/__init__.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/__init__.py
@@ -18,15 +18,23 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_tokenizers_available
+from ...utils import _LazyModule, OptionalDependencyNotAvailable, is_tokenizers_available
+
+
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
from ...utils import is_tf_available
+
+
{% endif %}
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
from ...utils import is_torch_available
+
+
{% endif %}
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
from ...utils import is_flax_available
+
+
{% endif %}
_import_structure = {
@@ -34,12 +42,22 @@
"tokenization_{{cookiecutter.lowercase_modelname}}": ["{{cookiecutter.camelcase_modelname}}Tokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_{{cookiecutter.lowercase_modelname}}_fast"] = ["{{cookiecutter.camelcase_modelname}}TokenizerFast"]
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_{{cookiecutter.lowercase_modelname}}"] = [
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForMaskedLM",
@@ -54,7 +72,12 @@
"load_tf_weights_in_{{cookiecutter.lowercase_modelname}}",
]
{% else %}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_{{cookiecutter.lowercase_modelname}}"] = [
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
@@ -70,7 +93,12 @@
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_{{cookiecutter.lowercase_modelname}}"] = [
"TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"TF{{cookiecutter.camelcase_modelname}}ForMaskedLM",
@@ -84,7 +112,12 @@
"TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
{% else %}
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_{{cookiecutter.lowercase_modelname}}"] = [
"TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"TF{{cookiecutter.camelcase_modelname}}Model",
@@ -96,7 +129,12 @@
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_{{cookiecutter.lowercase_modelname}}"] = [
"Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM",
"Flax{{cookiecutter.camelcase_modelname}}ForCausalLM",
@@ -109,7 +147,12 @@
"Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
{% else %}
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_{{cookiecutter.lowercase_modelname}}"] = [
"Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
@@ -125,12 +168,22 @@
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_{{cookiecutter.lowercase_modelname}}_fast import {{cookiecutter.camelcase_modelname}}TokenizerFast
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
@@ -145,7 +198,12 @@
load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
)
{% else %}
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
@@ -159,7 +217,12 @@
{% endif %}
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
@@ -173,7 +236,12 @@
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% else %}
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
TF{{cookiecutter.camelcase_modelname}}Model,
@@ -183,7 +251,12 @@
{% endif %}
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM,
Flax{{cookiecutter.camelcase_modelname}}ForCausalLM,
@@ -196,7 +269,12 @@
Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% else %}
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py
index 43fbad249518..676270c131fb 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py
@@ -25,6 +25,7 @@
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.linen.attention import dot_product_attention_weights
from jax import lax
@@ -126,6 +127,8 @@
"""
+remat = nn_partitioning.remat
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}}
@@ -507,11 +510,19 @@ def __call__(
class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ gradient_checkpointing: bool = False
def setup(self):
- self.layers = [
- Flax{{cookiecutter.camelcase_modelname}}Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
- ]
+ if self.gradient_checkpointing:
+ Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer = remat(Flax{{cookiecutter.camelcase_modelname}}Layer, static_argnums=(5, 6, 7))
+ self.layers = [
+ Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+ for i in range(self.config.num_hidden_layers)
+ ]
+ else:
+ self.layers = [
+ Flax{{cookiecutter.camelcase_modelname}}Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
+ ]
def __call__(
self,
@@ -545,12 +556,12 @@ def __call__(
layer_outputs = layer(
hidden_states,
attention_mask,
- layer_head_mask=head_mask[i] if head_mask is not None else None,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- init_cache=init_cache,
- deterministic=deterministic,
- output_attentions=output_attentions,
+ head_mask[i] if head_mask is not None else None,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ init_cache,
+ deterministic,
+ output_attentions,
)
hidden_states = layer_outputs[0]
@@ -581,9 +592,10 @@ def __call__(
class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ gradient_checkpointing: bool = False
def setup(self):
- self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype)
+ self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
def __call__(
self,
@@ -725,11 +737,20 @@ def __init__(
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
+ gradient_checkpointing: bool = False,
**kwargs
):
- module = self.module_class(config=config, dtype=dtype, **kwargs)
+ module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
+ def enable_gradient_checkpointing(self):
+ self._module = self.module_class(
+ config=self.config,
+ dtype=self.dtype,
+ gradient_checkpointing=True,
+ )
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights with Bert->{{cookiecutter.camelcase_modelname}}
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
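Because the template copies these methods from the Flax BERT implementation, enabling the new flag at the model level would look roughly like this (the concrete BERT class is used purely for illustration):

```python
# Sketch assuming FlaxBertModel exposes the same copied-from enable_gradient_checkpointing().
from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("bert-base-uncased")
model.enable_gradient_checkpointing()  # rebuilds the inner module with gradient_checkpointing=True
```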
@@ -897,10 +918,11 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
+ gradient_checkpointing: bool = False
def setup(self):
self.embeddings = Flax{{cookiecutter.camelcase_modelname}}Embeddings(self.config, dtype=self.dtype)
- self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype)
+ self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.pooler = Flax{{cookiecutter.camelcase_modelname}}Pooler(self.config, dtype=self.dtype)
def __call__(
@@ -969,9 +991,10 @@ class Flax{{cookiecutter.camelcase_modelname}}Model(Flax{{cookiecutter.camelcase
class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLMModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
def setup(self):
- self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+ self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
@@ -1030,9 +1053,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM(Flax{{cookiecutter.cam
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
def setup(self):
- self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+ self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
@@ -1092,9 +1116,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.cam
class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
def setup(self):
- self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
+ self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(
self.config.num_labels,
@@ -1163,9 +1188,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification(Flax{{co
class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoiceModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
def setup(self):
- self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
+ self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(1, dtype=self.dtype)
@@ -1238,9 +1264,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoice(Flax{{cookiecutt
class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassificationModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
def setup(self):
- self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+ self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
@@ -1302,9 +1329,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassification(Flax{{cooki
class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
def setup(self):
- self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+ self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing)
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
@@ -1373,9 +1401,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(Flax{{cookiec
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
def setup(self):
- self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+ self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
@@ -1996,7 +2025,7 @@ def setup(self) -> None:
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense(
- self.config.encoder_ffn_dim,
+ self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
@@ -2997,10 +3026,10 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
```python
>>> import jax
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
-
+
>>> model = Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
-
+
>>> TXT = "My friends are but they eat too many carbs."
>>> input_ids = tokenizer([TXT], return_tensors='np')['input_ids']
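
The Flax template hunks above thread a `gradient_checkpointing` flag from the pretrained wrapper down into the encoder. As a rough, self-contained sketch of the underlying mechanism (not the template code; `TinyBlock` and `TinyEncoder` are made-up names), `flax.linen.remat` can wrap a sub-module class so that its activations are recomputed during the backward pass instead of being stored:

import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyBlock(nn.Module):
    hidden_size: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_size)(x)
        return nn.relu(x)


class TinyEncoder(nn.Module):
    hidden_size: int = 16
    num_layers: int = 4
    gradient_checkpointing: bool = False  # same kind of switch the template threads through its modules

    def setup(self):
        # when enabled, rematerialize the block: its activations are recomputed
        # during backprop instead of being kept, trading compute for memory
        block_cls = nn.remat(TinyBlock) if self.gradient_checkpointing else TinyBlock
        self.layers = [block_cls(self.hidden_size) for _ in range(self.num_layers)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = TinyEncoder(gradient_checkpointing=True)
params = model.init(jax.random.PRNGKey(0), jnp.ones((2, 16)))
print(model.apply(params, jnp.ones((2, 16))).shape)  # (2, 16)

Enabling the flag trades extra forward compute for a smaller activation footprint, which is presumably why the wrapper exposes it as an opt-in rather than a default.
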
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
index f5c40b27d617..487b7c4461b1 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
@@ -1716,7 +1716,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
-def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -2821,7 +2821,7 @@ def __init__(self, config, *inputs, **kwargs):
self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model")
self.model._set_save_spec(inputs=self.serving.input_signature)
self.use_cache = config.use_cache
- # final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
+ # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self.final_logits_bias = self.add_weight(
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
)
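
For context on the `_expand_mask` simplification above (the unused `past_key_values_length` argument is dropped), here is an illustrative standalone sketch of what such a helper computes, written in plain TensorFlow with a made-up `LARGE_NEGATIVE` constant rather than the template's own code:

import tensorflow as tf

LARGE_NEGATIVE = -1e8  # stands in for the template's additive-mask constant


def expand_mask(mask: tf.Tensor, tgt_len: int = None) -> tf.Tensor:
    # broadcast a [bsz, src_len] padding mask to an additive [bsz, 1, tgt_len, src_len] mask
    src_len = int(mask.shape[1])
    tgt_len = src_len if tgt_len is None else tgt_len
    mask = tf.cast(mask, tf.float32)
    expanded = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
    # attended positions become 0, padded positions become a large negative number
    return (1.0 - expanded) * LARGE_NEGATIVE


attention_mask = tf.constant([[1, 1, 0]])
print(expand_mask(attention_mask, tgt_len=2).shape)  # (1, 1, 2, 3)
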
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
index 938cbea65c63..cbe8153c0ec7 100755
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
@@ -22,7 +22,6 @@
import torch
import torch.utils.checkpoint
-from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from typing import Optional, Tuple, Union
@@ -48,6 +47,7 @@
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
+ is_torch_greater_than_1_6,
)
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
@@ -157,7 +157,7 @@ def __init__(self, config):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
- if version.parse(torch.__version__) > version.parse("1.6.0"):
+ if is_torch_greater_than_1_6:
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
@@ -1632,7 +1632,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -2100,7 +2100,7 @@ def _set_gradient_checkpointing(self, module, value=False):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will
also be used by default.
- If you want to change padding behavior, you should read [`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_inputs`] and
+ If you want to change padding behavior, you should read [`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_attention_mask`] and
modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
@@ -2136,7 +2136,7 @@ def _set_gradient_checkpointing(self, module, value=False):
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
- instead of all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated
+ instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
@@ -2483,7 +2483,7 @@ def forward(
If `past_key_values` are used, the user can optionally input only the last
`decoder_input_ids` (those that don't have their past key value states given to this model) of
- shape `(batch_size, 1)` instead of all ``decoder_input_ids``` of shape `(batch_size,
+ shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size,
sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices
into associated vectors than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
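
One of the hunks above swaps `float("-inf")` for `torch.finfo(dtype).min` in the template's `_make_causal_mask`. A minimal sketch of the same construction (a hypothetical standalone helper, assuming the usual additive-mask convention) shows the idea: masked positions get the dtype's most negative finite value, which avoids `inf - inf = nan` surprises in fp16/bf16 softmax while still suppressing future tokens.

import torch


def make_causal_mask(tgt_len: int, dtype: torch.dtype) -> torch.Tensor:
    # fill with the dtype's minimum finite value instead of -inf
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
    mask_cond = torch.arange(tgt_len)
    # zero out the lower triangle, i.e. the positions a token may attend to
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
    return mask.to(dtype)


print(make_causal_mask(4, torch.float16))
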
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py
index c95b82115dc3..273adca0ef23 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py
@@ -115,7 +115,7 @@
{% endif -%}
# End.
-# Below: " # Fast tokenizers"
+# Below: " # Fast tokenizers structure"
# Replace with:
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].append("{{cookiecutter.camelcase_modelname}}TokenizerFast")
# End.
@@ -126,7 +126,7 @@
# End.
# To replace in: "src/transformers/__init__.py"
-# Below: " if is_torch_available():" if generating PyTorch
+# Below: " # PyTorch model imports" if generating PyTorch
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" %}
from .models.{{cookiecutter.lowercase_modelname}} import (
@@ -155,7 +155,7 @@
{% endif -%}
# End.
-# Below: " if is_tf_available():" if generating TensorFlow
+# Below: " # TensorFlow model imports" if generating TensorFlow
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" %}
from .models.{{cookiecutter.lowercase_modelname}} import (
@@ -179,7 +179,7 @@
{% endif -%}
# End.
-# Below: " if is_flax_available():" if generating Flax
+# Below: " # Flax model imports" if generating Flax
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" %}
from .models.{{cookiecutter.lowercase_modelname}} import (
@@ -204,7 +204,7 @@
{% endif -%}
# End.
-# Below: " if is_tokenizers_available():"
+# Below: " # Fast tokenizers imports"
# Replace with:
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast
# End.
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/tokenization_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/tokenization_{{cookiecutter.lowercase_modelname}}.py
index a3ad1dd7c9ff..a9c072f977d2 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/tokenization_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/tokenization_{{cookiecutter.lowercase_modelname}}.py
@@ -144,14 +144,14 @@ def __init__(
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
- "Initialisation"
+ """ Initialisation """
@property
def vocab_size(self):
- "Returns vocab size"
+ """ Returns vocab size """
def get_vocab(self):
- "Returns vocab as a dict"
+ """ Returns vocab as a dict """
def _tokenize(self, text):
""" Returns a tokenized string. """
diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index 9fba62815b01..65644ef3ac6d 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -20,10 +20,12 @@
import unittest
from copy import deepcopy
+import datasets
+
from parameterized import parameterized
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
-from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available
+from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available, unset_hf_deepspeed_config
from transformers.testing_utils import (
CaptureLogger,
CaptureStd,
@@ -40,7 +42,7 @@
slow,
)
from transformers.trainer_utils import get_last_checkpoint, set_seed
-from transformers.utils import WEIGHTS_NAME, is_torch_bf16_available
+from transformers.utils import WEIGHTS_NAME, is_torch_bf16_gpu_available
if is_torch_available():
@@ -127,7 +129,7 @@ def get_launcher(distributed=False):
BF16 = "bf16"
stages = [ZERO2, ZERO3]
-if is_torch_bf16_available():
+if is_torch_bf16_gpu_available():
dtypes = [FP16, BF16]
else:
dtypes = [FP16]
@@ -159,6 +161,12 @@ def setUp(self):
MASTER_ADDR="localhost", MASTER_PORT=master_port, RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"
)
+ def tearDown(self):
+ super().tearDown()
+
+ # reset the ds config global so that the tests' state doesn't leak
+ unset_hf_deepspeed_config()
+
def test_init_zero3_fp16(self):
# test that zero.Init() works correctly under zero3/fp16
ds_config = {
@@ -195,28 +203,7 @@ def test_init_zero3_fp16(self):
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
-@require_deepspeed
-@require_torch_gpu
-class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
- """
-
- This class is for testing directly via get_regression_trainer
-
- It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods
- which we can re-use here.
-
- Important: this class' setup can only work with a single gpu because it runs within the current
- pytest worker. For multi-gpu tests use TestDeepSpeedWithLauncher.
-
- Note: if any of the tests of this class get run there will be at least one gpu occupied by them
- until this pytest worker exits. This is because the gpu memory allocated by the cuda-kernels
- won't be released until this pytest worker exits.
-
- This may appear as some run-away tests if you watch `nvidia-smi` while other tests that fork new
- processes are run. So there will be one or two "stale" processes reported in `nvidia-smi`. This
- is not a bug.
- """
-
+class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
def setUp(self):
super().setUp()
@@ -248,10 +235,39 @@ def setUp(self):
zero3=config_zero3,
)
+ def tearDown(self):
+ super().tearDown()
+
+ # reset the ds config global so that the tests' state doesn't leak
+ unset_hf_deepspeed_config()
+
def get_config_dict(self, stage):
# As some tests modify the dict, always make a copy
return deepcopy(self.ds_config_dict[stage])
+
+@require_deepspeed
+@require_torch_gpu
+class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, TrainerIntegrationCommon):
+ """
+
+ This class is for testing directly via get_regression_trainer
+
+ It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods
+ which we can re-use here.
+
+ Important: this class' setup can only work with a single gpu because it runs within the current
+ pytest worker. For multi-gpu tests use TestDeepSpeedWithLauncher.
+
+ Note: if any of the tests of this class get run there will be at least one gpu occupied by them
+ until this pytest worker exits. This is because the gpu memory allocated by the cuda-kernels
+ won't be released until this pytest worker exits.
+
+ This may appear as some run-away tests if you watch `nvidia-smi` while other tests that fork new
+ processes are run. So there will be one or two "stale" processes reported in `nvidia-smi`. This
+ is not a bug.
+ """
+
# --- These tests are enough to run on one of zero stages --- #
def test_hf_ds_config_mismatch(self):
@@ -522,7 +538,7 @@ def test_gradient_accumulation(self, stage, dtype):
# see the note above how to get identical loss on a small bs
self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=2)
- def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage):
+ def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
# adapted from TrainerIntegrationCommon.check_saved_checkpoints
file_list = [WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
@@ -534,7 +550,8 @@ def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage):
else:
raise ValueError(f"unknown stage {stage}")
- ds_file_list.append("zero_pp_rank_0_mp_rank_00_optim_states.pt")
+ if dtype == "bf16":
+ ds_file_list.append("bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt")
for step in range(freq, total, freq):
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
@@ -578,7 +595,7 @@ def test_save_checkpoints(self, stage, dtype):
trainer.train()
total = int(self.n_epochs * 64 / self.batch_size)
- self.check_saved_checkpoints_deepspeed(output_dir, freq, total, stage)
+ self.check_saved_checkpoints_deepspeed(output_dir, freq, total, stage, dtype)
@parameterized.expand(params, name_func=parameterized_custom_name_func)
def test_can_resume_training_errors(self, stage, dtype):
@@ -724,6 +741,94 @@ def test_config_object(self):
self.assertFalse(is_deepspeed_zero3_enabled())
self.assertFalse(bool(config), "Deepspeed config should not be accessible")
+ @parameterized.expand(params, name_func=parameterized_custom_name_func)
+ def test_load_best_model(self, stage, dtype):
+ # Test that forced DeepSpeed reinit doesn't break the model. The forced re-init after
+ # loading the best model in Trainer is there to work around this bug in DeepSpeed:
+ # https://github.com/microsoft/DeepSpeed/issues/1612
+ #
+ # The test is derived from a repro script submitted in this issue:
+ # https://github.com/huggingface/transformers/issues/17114
+ #
+ # One additional feature of this test is that we use a non-AdamW optimizer to test that
+ # DeepSpeed doesn't fall back to AdamW, which would prevent the optimizer states from loading
+ # correctly.
+
+ from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer # noqa
+
+ output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False, before=False)
+
+ ds_config_dict = self.get_config_dict(stage)
+ del ds_config_dict["optimizer"] # will use HF Trainer optimizer
+ del ds_config_dict["scheduler"] # will use HF Trainer scheduler
+ # must use this setting to get the reload path exercised
+ ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
+
+ with mockenv_context(**self.dist_env_1_gpu):
+
+ args_dict = {
+ "per_gpu_train_batch_size": 1,
+ "per_gpu_eval_batch_size": 1,
+ "gradient_accumulation_steps": 1,
+ "learning_rate": 1e-4,
+ "num_train_epochs": 1,
+ "do_train": True,
+ "do_eval": True,
+ "optim": "adafactor",
+ "evaluation_strategy": "steps",
+ "eval_steps": 1,
+ "save_strategy": "steps",
+ "save_steps": 1,
+ "load_best_model_at_end": True,
+ "max_steps": 1,
+ "deepspeed": ds_config_dict,
+ }
+
+ training_args = TrainingArguments(output_dir, **args_dict)
+ tokenizer = T5Tokenizer.from_pretrained(T5_TINY)
+ model = T5ForConditionalGeneration.from_pretrained(T5_TINY)
+
+ def _add_eos_to_examples(example):
+ example["input_text"] = f"question: {example['question']} context: {example['context']}"
+ example["target_text"] = example["answers"]["text"][0] if len(example["answers"]["text"]) > 0 else ""
+ return example
+
+ def _convert_to_features(example_batch):
+ input_encodings = tokenizer.batch_encode_plus(
+ example_batch["input_text"], pad_to_max_length=True, max_length=512, truncation=True
+ )
+ target_encodings = tokenizer.batch_encode_plus(
+ example_batch["target_text"], pad_to_max_length=True, max_length=16, truncation=True
+ )
+
+ encodings = {
+ "input_ids": input_encodings["input_ids"],
+ "attention_mask": input_encodings["attention_mask"],
+ "labels": target_encodings["input_ids"],
+ }
+
+ return encodings
+
+ def get_dataset():
+ data_file = str(self.tests_dir / "fixtures/tests_samples/SQUAD/sample.json")
+ data_files = dict(train=data_file, validation=data_file)
+ raw_datasets = datasets.load_dataset("json", data_files=data_files, field="data")
+ train_dataset = raw_datasets["train"].map(_add_eos_to_examples).map(_convert_to_features, batched=True)
+ valid_dataset = deepcopy(train_dataset)
+ return train_dataset, valid_dataset
+
+ train_dataset, eval_dataset = get_dataset()
+
+ trainer = Trainer(
+ model=model,
+ tokenizer=tokenizer,
+ args=training_args,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ )
+ trainer.train() # crash 1 was here
+ trainer.evaluate() # crash 2 was here
+
@slow
@require_deepspeed
@@ -815,7 +920,7 @@ def test_resume_train_not_from_ds_checkpoint(self, stage, dtype):
@require_torch_multi_gpu
@parameterized.expand(["bf16", "fp16", "fp32"])
def test_inference(self, dtype):
- if dtype == "bf16" and not is_torch_bf16_available():
+ if dtype == "bf16" and not is_torch_bf16_gpu_available():
self.skipTest("test requires bfloat16 hardware support")
# this is just inference, so no optimizer should be loaded
@@ -1034,50 +1139,3 @@ def test_clm_from_config_zero3_fp16(self):
with CaptureStderr() as cs:
execute_subprocess_async(cmd, env=self.get_env())
self.assertIn("Detected DeepSpeed ZeRO-3", cs.err)
-
- @parameterized.expand(params, name_func=parameterized_custom_name_func)
- def test_load_best_model(self, stage, dtype):
- # this test exercises --load_best_model_at_end - the key is being able to resume after some training
-
- data_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro"
- output_dir = self.get_auto_remove_tmp_dir()
- args = f"""
- --model_name_or_path {T5_TINY}
- --tokenizer_name {T5_TINY}
- --train_file {data_dir}/train.json
- --validation_file {data_dir}/val.json
- --output_dir {output_dir}
- --overwrite_output_dir
- --source_lang en
- --target_lang ro
- --do_train
- --max_train_samples 3
- --do_eval
- --max_eval_samples 1
- --logging_strategy steps
- --logging_steps 1
- --evaluation_strategy steps
- --eval_steps 1
- --save_strategy steps
- --save_steps 1
- --load_best_model_at_end
- --per_device_train_batch_size 1
- --per_device_eval_batch_size 1
- --num_train_epochs 1
- --report_to none
- """.split()
- args.extend(["--source_prefix", "translate English to Romanian: "])
-
- args.extend([f"--{dtype}"])
-
- ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
- script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
- launcher = get_launcher(distributed=False)
-
- cmd = launcher + script + args + ds_args
- # keep for quick debug
- # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
- with CaptureStd() as cs:
- execute_subprocess_async(cmd, env=self.get_env())
- # enough to test it didn't fail
- self.assertIn("DeepSpeed info", cs.out)
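
The `tearDown` additions in the hunks above exist because `HfDeepSpeedConfig` registers itself in a module-level global that later tests would otherwise observe. A hedged sketch of the pattern, assuming `deepspeed` is installed and using the same `transformers.deepspeed` helpers the diff imports:

import unittest

from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled, unset_hf_deepspeed_config


class ExampleDeepSpeedConfigTest(unittest.TestCase):
    def setUp(self):
        # instantiating HfDeepSpeedConfig stores a weakref to it in a module-level global,
        # so keep the object alive on self for the duration of the test
        self.dschf = HfDeepSpeedConfig({"zero_optimization": {"stage": 3}, "train_batch_size": 8})

    def tearDown(self):
        # reset the global so the DeepSpeed config does not leak into unrelated tests
        unset_hf_deepspeed_config()

    def test_zero3_detected(self):
        self.assertTrue(is_deepspeed_zero3_enabled())

Without the tearDown reset, a later test that merely calls `is_deepspeed_zero3_enabled()` could pick up the leaked config and fail in confusing ways.
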
diff --git a/tests/deepspeed/test_model_zoo.py b/tests/deepspeed/test_model_zoo.py
index 94559604935a..ac33b7f5a279 100644
--- a/tests/deepspeed/test_model_zoo.py
+++ b/tests/deepspeed/test_model_zoo.py
@@ -42,51 +42,100 @@
set_seed(42)
+FIXTURE_DIRECTORY = get_tests_dir("fixtures")
+ROOT_DIRECTORY = os.path.join(dirname(get_tests_dir()))
+DS_TESTS_DIRECTORY = dirname(os.path.abspath(__file__))
+
# default torch.distributed port
DEFAULT_MASTER_PORT = "10999"
-# translation
-FSMT_TINY = "stas/tiny-wmt19-en-de"
-BART_TINY = "sshleifer/bart-tiny-random"
T5_SMALL = "t5-small"
-T5_TINY = "patrickvonplaten/t5-tiny-random"
-MBART_TINY = "sshleifer/tiny-mbart"
-MARIAN_TINY = "sshleifer/tiny-marian-en-de"
-
-# summarization
-PEGASUS_TINY = "stas/pegasus-cnn_dailymail-tiny-random"
-# causal lm
+# *** Working Models ***
+ALBERT_TINY = "hf-internal-testing/tiny-albert"
+BART_TINY = "sshleifer/bart-tiny-random"
+BERT_TINY = "hf-internal-testing/tiny-bert"
+BIGBIRD_PEGASUS_TINY = "hf-internal-testing/tiny-random-bigbird_pegasus"
+BIG_BIRD_TINY = "hf-internal-testing/tiny-random-big_bird"
+BLENDERBOT_TINY = "hf-internal-testing/tiny-random-blenderbot"
+BLOOM_TINY = "bigscience/bigscience-small-testing"
+DEBERTA_TINY = "hf-internal-testing/tiny-random-deberta"
+DEBERTA_V2_TINY = "hf-internal-testing/tiny-random-deberta-v2"
+DISTILBERT_TINY = "sshleifer/tiny-distilbert-base-cased"
+ELECTRA_TINY = "hf-internal-testing/tiny-electra"
+FLAUBERT_TINY = "hf-internal-testing/tiny-random-flaubert"
+FSMT_TINY = "stas/tiny-wmt19-en-de"
+FUNNEL_TINY = "hf-internal-testing/tiny-random-funnel"
GPT2_TINY = "sshleifer/tiny-gpt2"
+GPTJ_TINY = "hf-internal-testing/tiny-random-gptj"
+GPT_NEO_TINY = "hf-internal-testing/tiny-random-gpt_neo"
+LAYOUTLM_TINY = "hf-internal-testing/tiny-layoutlm"
+LED_TINY = "hf-internal-testing/tiny-random-led"
+LONGFORMER_TINY = "hf-internal-testing/tiny-random-longformer"
+M2M_100_TINY = "stas/tiny-m2m_100" # hf tiny model is unsuitable
+MARIAN_TINY = "sshleifer/tiny-marian-en-de"
+MBART_TINY = "sshleifer/tiny-mbart"
+MOBILEBERT_TINY = "hf-internal-testing/tiny-random-mobilebert"
+MPNET_TINY = "hf-internal-testing/tiny-random-mpnet"
+PEGASUS_TINY = "stas/pegasus-cnn_dailymail-tiny-random"
+PROPHETNET_TINY = "hf-internal-testing/tiny-random-prophetnet"
+ROBERTA_TINY = "sshleifer/tiny-distilroberta-base"
+SQUEEZEBERT_TINY = "hf-internal-testing/tiny-random-squeezebert"
+T5_TINY = "patrickvonplaten/t5-tiny-random"
+T5_V1_TINY = "hf-internal-testing/tiny-random-t5-v1.1"
+VIT_TINY = "hf-internal-testing/tiny-random-vit"
XLM_ROBERTA_TINY = "hf-internal-testing/tiny-xlm-roberta"
+XLNET_TINY = "sshleifer/tiny-xlnet-base-cased"
-# question-answering
-ROBERTA_TINY = "sshleifer/tiny-distilroberta-base"
-# masked lm
-DISTILBERT_TINY = "sshleifer/tiny-distilbert-base-cased"
-ELECTRA_TINY = "hf-internal-testing/tiny-electra"
+# *** To Fix ***
-# classification
-XLNET_TINY = "sshleifer/tiny-xlnet-base-cased"
-BERT_TINY = "hf-internal-testing/tiny-bert"
-FIXTURE_DIRECTORY = get_tests_dir("fixtures")
-ROOT_DIRECTORY = os.path.join(dirname(get_tests_dir()))
+# *** tiny model issues ***
+# missing model files:
+MT5_TINY = "hf-internal-testing/tiny-random-mt5"
+CAMEMBERT_TINY = "hf-internal-testing/tiny-random-camembert"
+OPENAI_GPT_TINY = "hf-internal-testing/tiny-random-openai-gpt"
+
+# missing tokenizer files
+CONVBERT_TINY = "hf-internal-testing/tiny-random-convbert"
+LAYOUTLMV2_TINY = "hf-internal-testing/tiny-random-layoutlmv2"
+HUBERT_TINY = "hf-internal-testing/tiny-random-hubert"
+
+# issues with tokenizer
+CTRL_TINY = "hf-internal-testing/tiny-random-ctrl"
+TRANSFO_XL_TINY = "hf-internal-testing/tiny-random-transfo-xl" # same as ctrl
-# TODO: to add:
-# albert
-# deberta
-# funnel
-# longformer
-# dpr
-# gpt_neo
-# camembert
-# deberta-v2
-# m2m_100
-# tapas
-# vit
-# big_bird
+# other issues with tiny models
+IBERT_TINY = "hf-internal-testing/tiny-random-ibert" # multiple issues with either mlm/qa/clas
+REFORMER_TINY = "hf-internal-testing/tiny-random-reformer" # multiple issues with either mlm/qa/clas
+
+# *** Lacking official examples to test with ***
+# or not working with examples
+DPR_TINY = "hf-internal-testing/tiny-random-dpr"
+# - "dpr" examples/research_projects/rag-end2end-retriever/
+RAG_TINY = "hf-internal-testing/tiny-random-rag"
+# - "rag" research_projects
+LUKE_TINY = ""
+# - "luke" Entities classes - no plan to make such example
+LXMERT_TINY = "hf-internal-testing/tiny-random-lxmert"
+# - "lxmert" doesn't work with run_qa.py
+CLIP_TINY = "hf-internal-testing/tiny-random-clip"
+# - "clip" nothing under pytorch examples - XXX: Suraj is working on adding some - check by end of Sep
+SPEECH_TO_TEXT_TINY = "hf-internal-testing/tiny-random-speech_to_text"
+# - "speech_to_text", nothing under pytorch examples
+
+
+# *** Reactive mode ***
+# models with low usage, unstable API, things about to change - do nothing about the following until someone runs into a problem
+TAPAS_TINY = "hf-internal-testing/tiny-random-tapas"
+# additional notes on tapas
+# 1. requires torch_scatter - skip if it's not installed?
+# 2. "Table must be of type pd.DataFrame" failure
+
+
+# TODO: new models to add:
+#
def get_launcher(distributed=False):
@@ -113,35 +162,69 @@ def make_task_cmds():
--overwrite_output_dir
""".split()
- # XXX: try to cover as many models as possible once (it's enough to run on one task per model)
+ # try to cover as many models as possible once (it's enough to run on one task per model)
# but need a tiny model for each
#
- # should have T5_TINY, etc. global var defined
+ # should have "{model_type.upper()}_TINY" corresponding vars defined, e.g., T5_TINY, etc.
tasks2models = dict(
trans=[
"bart",
"fsmt",
+ "m2m_100",
"marian",
"mbart",
"t5",
+ "t5_v1",
+ # "mt5", missing model files
],
sum=[
"pegasus",
],
clm=[
+ "big_bird",
+ "bigbird_pegasus",
+ "blenderbot",
+ "bloom",
"gpt2",
+ "gpt_neo",
+ "gptj",
"xlm-roberta",
+ "prophetnet",
+ # "camembert", missing model files
],
mlm=[
- "electra",
+ "albert",
+ "deberta",
+ "deberta-v2",
"distilbert",
+ "electra",
+ "flaubert",
+ "funnel",
+ "layoutlm",
+ # "reformer", # multiple issues with either mlm/qa/clas
],
qa=[
+ "led",
+ "longformer",
+ "mobilebert",
+ "mpnet",
"roberta",
+ "squeezebert",
+ # "convbert", # missing tokenizer files
+ # "layoutlmv2", missing model files
],
clas=[
"bert",
"xlnet",
+ # "hubert", # missing tokenizer files
+ # "ibert", # multiple issues with either mlm/qa/clas
+ # "transfo-xl", # tokenizer issues as ctrl
+ # "ctrl", # tokenizer issues
+ # "openai-gpt", missing model files
+ # "tapas", multiple issues
+ ],
+ img_clas=[
+ "vit",
],
)
@@ -180,6 +263,13 @@ def make_task_cmds():
--max_seq_length 12
--task_name MRPC
""",
+ img_clas=f"""
+ {scripts_dir}/image-classification/run_image_classification.py
+ --dataset_name hf-internal-testing/cats_vs_dogs_sample
+ --remove_unused_columns False
+ --max_steps 10
+ --feature_extractor_name {DS_TESTS_DIRECTORY}/vit_feature_extractor.json
+ """,
)
launcher = get_launcher(distributed=True)
@@ -216,7 +306,7 @@ def make_task_cmds():
#
# dtypes = [FP16]
# so just hardcoding --fp16 for now
-# if is_torch_bf16_available():
+# if is_torch_bf16_gpu_available():
# dtypes += [BF16]
diff --git a/tests/deepspeed/vit_feature_extractor.json b/tests/deepspeed/vit_feature_extractor.json
new file mode 100644
index 000000000000..bfe5a331249f
--- /dev/null
+++ b/tests/deepspeed/vit_feature_extractor.json
@@ -0,0 +1,4 @@
+{
+ "feature_extractor_type": "ViTFeatureExtractor",
+ "size": 30
+}
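
The small fixture above is enough for the new image-classification task command because a feature extractor can be instantiated straight from a JSON file; a brief illustrative usage (the path is relative to the repository root):

from transformers import AutoFeatureExtractor

# "feature_extractor_type": "ViTFeatureExtractor" in the JSON resolves the class,
# and "size": 30 keeps the tiny ViT inputs small for the test
feature_extractor = AutoFeatureExtractor.from_pretrained("tests/deepspeed/vit_feature_extractor.json")
print(type(feature_extractor).__name__, feature_extractor.size)  # ViTFeatureExtractor 30
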
diff --git a/tests/extended/test_trainer_ext.py b/tests/extended/test_trainer_ext.py
index 3d88ebda4559..64c244ae8ed2 100644
--- a/tests/extended/test_trainer_ext.py
+++ b/tests/extended/test_trainer_ext.py
@@ -105,6 +105,7 @@ def test_run_seq2seq_ddp(self):
self.run_seq2seq_quick(distributed=True)
# test --sharded_ddp w/o --fp16
+ @unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_sharded_ddp(self):
@@ -118,6 +119,7 @@ def test_run_seq2seq_sharded_ddp_fp16(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
# test --sharded_ddp zero_dp_2 w/o --fp16
+ @unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_fully_sharded_ddp(self):
@@ -278,7 +280,8 @@ def train_and_return_metrics(optim: str) -> Tuple[int, float]:
self.assertGreater(
gpu_total_mem_diff_bytes,
bnb_saved_bytes * 0.8, # add a safety margin, if it saved slightly less
- f"BNB should have saved about {bnb_saved_bytes} bytes, but the saved bytes were {gpu_total_mem_diff_bytes}",
+ f"BNB should have saved about {bnb_saved_bytes} bytes, but the saved bytes were"
+ f" {gpu_total_mem_diff_bytes}",
)
def run_trainer(
diff --git a/tests/generation/test_generation_beam_search.py b/tests/generation/test_generation_beam_search.py
index 3971dcc79c35..885cefa62cbd 100644
--- a/tests/generation/test_generation_beam_search.py
+++ b/tests/generation/test_generation_beam_search.py
@@ -126,7 +126,11 @@ def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_sc
tokens = next_tokens.clone()
tokens[:, : self.num_beams] = self.eos_token_id
- beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
+ beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
+ beam_indices = tuple(tuple(b) for b in beam_indices)
+ beam_scorer.process(
+ input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
+ )
# beam scorer should be done
self.parent.assertTrue(beam_scorer.is_done)
@@ -136,7 +140,7 @@ def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_sc
tokens = next_tokens.clone()
tokens[:, 1] = self.eos_token_id
beam_outputs = beam_scorer.process(
- input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
+ input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
@@ -161,10 +165,15 @@ def cut_expected_tensor(tensor):
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
+ expected_beam_indices = list(range(10))
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
- input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
+ input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
+ )
+ self.parent.assertListEqual(
+ expected_beam_indices + [next_indices[batch_idx, 1].item()],
+ torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
)
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
@@ -188,6 +197,8 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
# finalize
+ beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
+ beam_indices = tuple(tuple(b) for b in beam_indices)
sequence_output = beam_scorer.finalize(
input_ids,
output_scores,
@@ -196,6 +207,7 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
+ beam_indices=beam_indices,
)
sequences = sequence_output["sequences"]
@@ -225,6 +237,7 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
+ beam_indices=beam_indices,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
@@ -394,7 +407,7 @@ def cut_expected_tensor(tensor):
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
- input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
+ input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
)
def check_constrained_beam_scorer_finalize(
@@ -464,7 +477,7 @@ def check_constrained_beam_scorer_finalize(
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
# test that the constraint is indeed fulfilled
- for (output, constraint) in [(s, c) for s in sequences for c in constraints]:
+ for output, constraint in [(s, c) for s in sequences for c in constraints]:
forced_token_ids = constraint.token_ids
if isinstance(forced_token_ids[0], list):
# disjunctive case
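
The updated assertions above reflect that each finished hypothesis now carries its `beam_indices` alongside the score and tokens (hence `beams[0][1]` for the token ids and `beams[0][2]` for the indices). As an illustrative-only sketch with toy tensors (not the `BeamSearchScorer` API), this is what those indices make possible: gathering the per-step score of every generated token back out of the step-wise score tensors.

import torch

num_beams, vocab_size = 3, 5
# scores[t][beam] = log-probs over the vocabulary at step t for each live beam
scores = [torch.randn(num_beams, vocab_size).log_softmax(-1) for _ in range(4)]
token_ids = [2, 4, 1, 3]     # tokens of one finished hypothesis
beam_indices = [0, 2, 2, 1]  # beam each token was drawn from at that step

transition_scores = torch.stack(
    [scores[t][beam_indices[t], token_ids[t]] for t in range(len(token_ids))]
)
# in generate(), with length_penalty=0.0 the sum of these per-step scores matches the hypothesis score
print(transition_scores.sum())
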
diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py
index 6006dbe21cdf..56227403ae60 100644
--- a/tests/generation/test_generation_utils.py
+++ b/tests/generation/test_generation_utils.py
@@ -1626,6 +1626,32 @@ def test_top_k_top_p_filtering(self):
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
+ # tests whether the function uses filter_value instead of default -inf
+ def test_top_k_top_p_filtering_with_filter_value(self):
+ logits = torch.tensor(
+ [
+ [
+ 1,
+ 1,
+ 1,
+ 0.99, # get filtered by top-p filtering
+ 0.98, # get filtered by top-k filtering
+ ]
+ ],
+ dtype=torch.float,
+ device=torch_device,
+ )
+
+ expected_output = torch.tensor(
+ [[1, 1, 1, 0, 0]],
+ dtype=torch.float,
+ device=torch_device,
+ )
+
+ output = top_k_top_p_filtering(logits, top_k=4, top_p=0.5, filter_value=0.0)
+
+ self.assertTrue(torch.allclose(expected_output, output, atol=1e-12))
+
@require_torch
class GenerationIntegrationTests(unittest.TestCase):
@@ -1654,8 +1680,12 @@ def test_diverse_beam_search(self):
self.assertListEqual(
generated_text,
[
- "The couple announced the birth of their son, Silas Randall Timberlake, in a statement. Silas was the middle name of Timberlake's maternal grandfather Bill Bomar. Randall is the musician's own middle name, as well as his father's first. It is the first baby for both of them.",
- "Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the first child for both. The couple announced the pregnancy in January. The name Silas is the middle name of Timberlake's maternal grandfather. It's also his own middle name.",
+ "The couple announced the birth of their son, Silas Randall Timberlake, in a statement. Silas was the"
+ " middle name of Timberlake's maternal grandfather Bill Bomar. Randall is the musician's own middle"
+ " name, as well as his father's first. It is the first baby for both of them.",
+ "Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the"
+ " first child for both. The couple announced the pregnancy in January. The name Silas is the middle"
+ " name of Timberlake's maternal grandfather. It's also his own middle name.",
],
)
@@ -1993,8 +2023,8 @@ def test_max_new_tokens_encoder_decoder(self):
# 1 BOS + 20 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
- # max_new_tokens and max_length serve the same purpose and should not be used together.
- with self.assertWarns(UserWarning):
+ # max_new_tokens and max_length serve the same purpose and must not be used together.
+ with self.assertRaises(ValueError):
bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
def test_max_new_tokens_decoder_only(self):
@@ -2020,8 +2050,8 @@ def test_max_new_tokens_decoder_only(self):
# 1 BOS token + 23 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
- # max_new_tokens and max_length serve the same purpose and should not be used together.
- with self.assertWarns(UserWarning):
+ # max_new_tokens and max_length serve the same purpose and must not be used together.
+ with self.assertRaises(ValueError):
gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
def test_encoder_decoder_generate_with_inputs_embeds(self):
@@ -2318,6 +2348,94 @@ def test_transition_scores_group_beam_search_encoder_decoder(self):
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
+ @slow
+ def test_transition_scores_early_stopping(self):
+ # This is an aggressive test that makes sure that `beam_search's`
+ # transition scores are computed correctly for varying `num_return_sequences`,
+ # `num_beams` and `batch_size > 1`
+ # 2 x input_ids for "question: How are you? \n context: I had a long day, "
+ input_ids = torch.tensor(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]]).to(
+ torch_device
+ )
+
+ model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(torch_device)
+
+ result = model.generate(
+ input_ids,
+ max_length=10,
+ return_dict_in_generate=True,
+ output_scores=True,
+ forced_eos_token_id=model.config.eos_token_id,
+ num_beams=4,
+ do_sample=False,
+ num_return_sequences=3,
+ length_penalty=0.0,
+ )
+
+ transition_scores = model.compute_transition_beam_scores(
+ sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices
+ )
+
+ sum_transition_scores = torch.sum(transition_scores, dim=1)
+
+ self.assertListEqual(sum_transition_scores.cpu().tolist(), result.sequences_scores.cpu().tolist())
+
+ def test_log_scores_sample_decoder_only(self):
+ articles = ["I need input_ids to generate", "Short and"]
+ tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+ tokenizer.padding_side = "left"
+ tokenizer.pad_token = tokenizer.eos_token
+
+ model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+
+ inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
+
+ result = model.generate(
+ **inputs,
+ max_length=15,
+ return_dict_in_generate=True,
+ do_sample=False,
+ output_scores=True,
+ )
+
+ # decoder-only starts generating from `input_ids`
+ begin_generation = inputs.input_ids.shape[-1]
+
+ gen_sequences = result.sequences[:, begin_generation:]
+ probs = torch.stack(result.scores, dim=1).softmax(-1)
+
+ gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
+ expected_probs = torch.tensor([[0.0014, 0.0015], [0.0014, 0.0014]])
+
+ self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
+
+ def test_log_scores_sample_encoder_decoder(self):
+ articles = ["I need input_ids to generate", "Short and"]
+ tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+ model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device)
+
+ inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
+
+ result = model.generate(
+ **inputs,
+ max_length=3,
+ return_dict_in_generate=True,
+ do_sample=False,
+ num_beams=1,
+ output_scores=True,
+ )
+
+ # encoder-decoder has one decoder_start_token_id by default
+ begin_generation = 1
+
+ gen_sequences = result.sequences[:, begin_generation:]
+ probs = torch.stack(result.scores, dim=1).softmax(-1)
+
+ gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
+ expected_probs = torch.tensor([[0.0013, 1.0000], [0.0013, 1.0000]])
+
+ self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
+
@slow
def test_beam_search_example_integration(self):
# exactly the example provided in the docstrings of beam search, which previously
@@ -2362,8 +2480,8 @@ def test_beam_search_example_integration(self):
@slow
def test_constrained_beam_search(self):
- model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
- tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
+ model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
@@ -2392,14 +2510,15 @@ def test_constrained_beam_search(self):
self.assertListEqual(
generated_text,
[
- "The soldiers were not prepared and didn't know how big the big weapons would be, so they scared them off. They had no idea what to do",
+ "The soldiers were not prepared and didn't know what to do. They had no idea how they would react if"
+ " the enemy attacked them, big weapons scared"
],
)
@slow
def test_constrained_beam_search_mixed(self):
- model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
- tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
+ model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
flexible_phrases = tokenizer(
@@ -2430,15 +2549,16 @@ def test_constrained_beam_search_mixed(self):
self.assertListEqual(
generated_text,
[
- "The soldiers, who were all scared and screaming at each other as they tried to get out of the",
- "The child was taken to a local hospital where she screamed and scared for her life, police said.",
+ "The soldiers, who had been stationed at the base for more than a year before being evacuated"
+ " screaming scared",
+ "The child was taken to a local hospital where he died.\n 'I don't think screaming scared",
],
)
@slow
def test_constrained_beam_search_mixed_mixin(self):
- model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
- tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
+ model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_word = "scared"
force_flexible = ["scream", "screams", "screaming", "screamed"]
@@ -2466,8 +2586,9 @@ def test_constrained_beam_search_mixed_mixin(self):
self.assertListEqual(
generated_text,
[
- "The soldiers, who were all scared and screaming at each other as they tried to get out of the",
- "The child was taken to a local hospital where she screamed and scared for her life, police said.",
+ "The soldiers, who had been stationed at the base for more than a year before being evacuated"
+ " screaming scared",
+ "The child was taken to a local hospital where he died.\n 'I don't think screaming scared",
],
)
@@ -2493,7 +2614,7 @@ def test_constrained_beam_search_example_translation_mixin(self):
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
- self.assertListEqual(outputs, ["Wie alter sind Sie?"])
+ self.assertListEqual(outputs, ["Wie alt sind Sie?"])
@slow
def test_constrained_beam_search_example_integration(self):
@@ -2537,11 +2658,11 @@ def test_constrained_beam_search_example_integration(self):
)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
- self.assertListEqual(outputs, ["Wie alter sind Sie?"])
+ self.assertListEqual(outputs, ["Wie alt sind Sie?"])
def test_constrained_beam_search_mixin_type_checks(self):
- tokenizer = AutoTokenizer.from_pretrained("t5-base")
- model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
+ tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
+ model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
encoder_input_str = "translate English to German: How old are you?"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
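
For reference alongside the new `filter_value` test above, a minimal standalone use of the public `top_k_top_p_filtering` helper (the numbers mirror the test's expectations):

import torch

from transformers import top_k_top_p_filtering

logits = torch.tensor([[1.0, 1.0, 1.0, 0.99, 0.98]])
# tokens removed by top-k / top-p filtering receive `filter_value` instead of the default -inf
filtered = top_k_top_p_filtering(logits, top_k=4, top_p=0.5, filter_value=0.0)
print(filtered)  # tensor([[1., 1., 1., 0., 0.]])
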
diff --git a/tests/models/auto/test_configuration_auto.py b/tests/models/auto/test_configuration_auto.py
index eeb10ad2d315..2695082c412d 100644
--- a/tests/models/auto/test_configuration_auto.py
+++ b/tests/models/auto/test_configuration_auto.py
@@ -14,6 +14,8 @@
# limitations under the License.
import importlib
+import json
+import os
import sys
import tempfile
import unittest
@@ -56,14 +58,14 @@ def test_config_for_model_str(self):
self.assertIsInstance(config, RobertaConfig)
def test_pattern_matching_fallback(self):
- """
- In cases where config.json doesn't include a model_type,
- perform a few safety checks on the config mapping's order.
- """
- # no key string should be included in a later key string (typical failure case)
- keys = list(CONFIG_MAPPING.keys())
- for i, key in enumerate(keys):
- self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # This model name contains bert and roberta, but roberta ends up being picked.
+ folder = os.path.join(tmp_dir, "fake-roberta")
+ os.makedirs(folder, exist_ok=True)
+ with open(os.path.join(folder, "config.json"), "w") as f:
+ f.write(json.dumps({}))
+ config = AutoConfig.from_pretrained(folder)
+ self.assertEqual(type(config), RobertaConfig)
def test_new_config_registration(self):
try:
diff --git a/tests/models/auto/test_processor_auto.py b/tests/models/auto/test_processor_auto.py
index 26122e6164ab..2f99d5c379bc 100644
--- a/tests/models/auto/test_processor_auto.py
+++ b/tests/models/auto/test_processor_auto.py
@@ -21,7 +21,7 @@
from pathlib import Path
from shutil import copyfile
-from huggingface_hub import Repository, delete_repo, login
+from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
from requests.exceptions import HTTPError
from transformers import (
CONFIG_MAPPING,
@@ -36,7 +36,7 @@
Wav2Vec2FeatureExtractor,
Wav2Vec2Processor,
)
-from transformers.testing_utils import PASS, USER, get_tests_dir, is_staging_test
+from transformers.testing_utils import TOKEN, USER, get_tests_dir, is_staging_test
from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE
from transformers.utils import FEATURE_EXTRACTOR_NAME, is_tokenizers_available
@@ -209,22 +209,24 @@ class ProcessorPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls._token = login(username=USER, password=PASS)
+ cls._token = TOKEN
+ set_access_token(TOKEN)
+ HfFolder.save_token(TOKEN)
@classmethod
def tearDownClass(cls):
try:
- delete_repo(token=cls._token, name="test-processor")
+ delete_repo(token=cls._token, repo_id="test-processor")
except HTTPError:
pass
try:
- delete_repo(token=cls._token, name="test-processor-org", organization="valid_org")
+ delete_repo(token=cls._token, repo_id="valid_org/test-processor-org")
except HTTPError:
pass
try:
- delete_repo(token=cls._token, name="test-dynamic-processor")
+ delete_repo(token=cls._token, repo_id="test-dynamic-processor")
except HTTPError:
pass
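
The auth changes above move the staging tests from a username/password `login` to a pre-provisioned token. A hedged sketch of that flow using the same `huggingface_hub` helpers the diff imports (the token string is a placeholder, so the calls are not expected to succeed as-is):

from huggingface_hub import HfFolder, delete_repo, set_access_token

TOKEN = "hf_xxx"  # placeholder, not a real token
set_access_token(TOKEN)     # register the token, as in setUpClass above
HfFolder.save_token(TOKEN)  # cache it locally so later hub calls pick it up implicitly

# repos are now addressed by repo_id, with the organization as a prefix
delete_repo(token=TOKEN, repo_id="valid_org/test-processor-org")
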
diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py
index 18fc66a4f563..5ef86523ebd6 100644
--- a/tests/models/bart/test_modeling_bart.py
+++ b/tests/models/bart/test_modeling_bart.py
@@ -116,6 +116,12 @@ def __init__(
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
+ # forcing a certain token to be generated, sets all other tokens to -inf
+ # if however the token to be generated is already at -inf then it can lead token
+ # `nan` values and thus break generation
+ self.forced_bos_token_id = None
+ self.forced_eos_token_id = None
+
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
@@ -145,11 +151,14 @@ def get_config(self):
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
+ forced_bos_token_id=self.forced_bos_token_id,
+ forced_eos_token_id=self.forced_eos_token_id,
)
def get_pipeline_config(self):
config = self.get_config()
config.max_position_embeddings = 100
+ config.vocab_size = 300
return config
def prepare_config_and_inputs_for_common(self):
@@ -413,6 +422,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
)
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
@@ -521,8 +531,47 @@ def xsum_1_1_model(self):
def test_xsum_1_1_generation(self):
hf = self.xsum_1_1_model
tok = self.tok
- ARTICLE = 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.'
- EXPECTED = " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court."
+ ARTICLE = (
+ "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes."
+ )
+ EXPECTED = (
+ " The International Criminal Court (ICC) has announced that it has been announced by the International"
+ " Criminal court."
+ )
dct = tok(ARTICLE, return_tensors="pt")
generated_ids = hf.generate(**dct, num_beams=4)
@@ -534,8 +583,116 @@ def test_xsum_1_1_batch_generation(self):
batch = self.tok(
[
- 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.',
- 'The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.',
+ "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories."
+ " The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is"
+ " based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted"
+ ' its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including'
+ ' East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination'
+ " into the situation in Palestinian territories, paving the way for possible war crimes investigations"
+ " against Israelis. As members of the court, Palestinians may be subject to counter-charges as well."
+ " Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts"
+ " to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony,"
+ ' said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome'
+ ' Statute today, the world is also a step closer to ending a long era of impunity and injustice," he'
+ ' said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of'
+ ' justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was'
+ ' just the first step for the Palestinians. "As the Rome Statute today enters into force for the State'
+ " of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a"
+ ' State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she'
+ ' said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize'
+ " Palestine for joining the ICC should immediately end their pressure, and countries that support"
+ " universal acceptance of the court's treaty should speak out to welcome its membership,\" said"
+ " Balkees Jarrah, international justice counsel for the group. \"What's objectionable is the attempts"
+ " to undermine international justice, not Palestine's decision to join a treaty to which over 100"
+ ' countries around the world are members." In January, when the preliminary ICC examination was'
+ " opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was"
+ ' overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s'
+ ' decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we'
+ ' do not believe that it is eligible to join the ICC," the State Department said in a statement. It'
+ ' urged the warring sides to resolve their differences through direct negotiations. "We will continue'
+ ' to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said.'
+ " But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows'
+ " the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor"
+ ' Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality."'
+ " The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The"
+ " inquiry will include alleged war crimes committed since June. The International Criminal Court was"
+ " set up in 2002 to prosecute genocide, crimes against humanity and war crimes.",
+ "The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted"
+ " Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor"
+ ' Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A'
+ " person who has such a video needs to immediately give it to the investigators.\" Robin's comments"
+ " follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the"
+ " French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was"
+ " recovered from a phone at the wreckage site. The two publications described the supposed video, but"
+ " did not post it on their websites. The publications said that they watched the video, which was"
+ " found by a source close to the investigation. \"One can hear cries of 'My God' in several"
+ ' languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps'
+ " of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy"
+ ' shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing'
+ " scene,\" said Julian Reichelt, editor-in-chief of Bild online. An official with France's accident"
+ " investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc"
+ " Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the"
+ ' Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell'
+ ' phones have been collected at the site, he said, but that they "hadn\'t been exploited yet."'
+ " Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute"
+ " in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working"
+ " hand-in-hand with investigators. But none of the cell phones found so far have been sent to the"
+ " institute, Menichini said. Asked whether staff involved in the search could have leaked a memory"
+ ' card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett:'
+ ' Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are'
+ ' "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is'
+ " something we did not know before. ... Overall we can say many things of the investigation weren't"
+ ' revealed by the investigation at the beginning," he said. What was mental state of Germanwings'
+ " co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled"
+ " depression years before he took the controls of Germanwings Flight 9525, which he's accused of"
+ " deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school"
+ ' in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email'
+ " correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa"
+ " said, included medical documents he submitted in connection with resuming his flight training. The"
+ " announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz's battle"
+ " with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa,"
+ " whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday"
+ ' as a "swift and seamless clarification" and said it was sharing the information and documents --'
+ " including training and medical records -- with public prosecutors. Spohr traveled to the crash site"
+ " Wednesday, where recovery teams have been working for the past week to recover human remains and"
+ " plane debris scattered across a steep mountainside. He saw the crisis center set up in"
+ " Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving"
+ " families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no"
+ " visible human remains were left at the site but recovery teams would keep searching. French"
+ " President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the"
+ " victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini"
+ " said. Among those personal belongings could be more cell phones belonging to the 144 passengers and"
+ " six crew on board. Check out the latest from our correspondents . The details about Lubitz's"
+ " correspondence with the flight school during his training were among several developments as"
+ " investigators continued to delve into what caused the crash and Lubitz's possible motive for"
+ " downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical"
+ ' certificate, had passed all his examinations and "held all the licenses required." Earlier, a'
+ " spokesman for the prosecutor's office in Dusseldorf, Christoph Kumpa, said medical records reveal"
+ " Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent"
+ " psychotherapy before he got his pilot's license. Kumpa emphasized there's no evidence suggesting"
+ " Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether"
+ " Lubitz feared his medical condition would cause him to lose his pilot's license, a European"
+ ' government official briefed on the investigation told CNN on Tuesday. While flying was "a big part'
+ " of his life,\" the source said, it's only one theory being considered. Another source, a law"
+ " enforcement official briefed on the investigation, also told CNN that authorities believe the"
+ " primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly"
+ " because of his medical problems. Lubitz's girlfriend told investigators he had seen an eye doctor"
+ " and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had"
+ " psychological issues, the European government official said. But no matter what details emerge about"
+ " his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the'
+ " fact that maybe they weren't going to keep doing their job and they're upset about that and so"
+ ' they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels'
+ " entitled to also take that rage and turn it outward on 149 other people who had nothing to do with"
+ " the person's problems.\" Germanwings crash compensation: What we know . Who was the captain of"
+ " Germanwings Flight 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from"
+ " Dusseldorf, while Laura Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff,"
+ " Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.",
],
return_tensors="pt",
padding="longest",
@@ -545,11 +702,13 @@ def test_xsum_1_1_batch_generation(self):
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)
assert (
result[0]
- == " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court."
+ == " The International Criminal Court (ICC) has announced that it has been announced by the International"
+ " Criminal court."
)
assert (
result[1]
- == " An investigation into the crash that killed at least 10 people in the French capital has been released by the French police investigating the crash."
+ == " An investigation into the crash that killed at least 10 people in the French capital has been"
+ " released by the French police investigating the crash."
)
def test_encoder_equiv(self):
@@ -557,8 +716,116 @@ def test_encoder_equiv(self):
batch = self.tok(
[
- 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.',
- 'The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.',
+ "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories."
+ " The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is"
+ " based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted"
+ ' its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including'
+ ' East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination'
+ " into the situation in Palestinian territories, paving the way for possible war crimes investigations"
+ " against Israelis. As members of the court, Palestinians may be subject to counter-charges as well."
+ " Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts"
+ " to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony,"
+ ' said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome'
+ ' Statute today, the world is also a step closer to ending a long era of impunity and injustice," he'
+ ' said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of'
+ ' justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was'
+ ' just the first step for the Palestinians. "As the Rome Statute today enters into force for the State'
+ " of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a"
+ ' State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she'
+ ' said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize'
+ " Palestine for joining the ICC should immediately end their pressure, and countries that support"
+ " universal acceptance of the court's treaty should speak out to welcome its membership,\" said"
+ " Balkees Jarrah, international justice counsel for the group. \"What's objectionable is the attempts"
+ " to undermine international justice, not Palestine's decision to join a treaty to which over 100"
+ ' countries around the world are members." In January, when the preliminary ICC examination was'
+ " opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was"
+ ' overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s'
+ ' decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we'
+ ' do not believe that it is eligible to join the ICC," the State Department said in a statement. It'
+ ' urged the warring sides to resolve their differences through direct negotiations. "We will continue'
+ ' to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said.'
+ " But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows'
+ " the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor"
+ ' Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality."'
+ " The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The"
+ " inquiry will include alleged war crimes committed since June. The International Criminal Court was"
+ " set up in 2002 to prosecute genocide, crimes against humanity and war crimes.",
+ "The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted"
+ " Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor"
+ ' Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A'
+ " person who has such a video needs to immediately give it to the investigators.\" Robin's comments"
+ " follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the"
+ " French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was"
+ " recovered from a phone at the wreckage site. The two publications described the supposed video, but"
+ " did not post it on their websites. The publications said that they watched the video, which was"
+ " found by a source close to the investigation. \"One can hear cries of 'My God' in several"
+ ' languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps'
+ " of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy"
+ ' shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing'
+ " scene,\" said Julian Reichelt, editor-in-chief of Bild online. An official with France's accident"
+ " investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc"
+ " Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the"
+ ' Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell'
+ ' phones have been collected at the site, he said, but that they "hadn\'t been exploited yet."'
+ " Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute"
+ " in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working"
+ " hand-in-hand with investigators. But none of the cell phones found so far have been sent to the"
+ " institute, Menichini said. Asked whether staff involved in the search could have leaked a memory"
+ ' card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett:'
+ ' Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are'
+ ' "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is'
+ " something we did not know before. ... Overall we can say many things of the investigation weren't"
+ ' revealed by the investigation at the beginning," he said. What was mental state of Germanwings'
+ " co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled"
+ " depression years before he took the controls of Germanwings Flight 9525, which he's accused of"
+ " deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school"
+ ' in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email'
+ " correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa"
+ " said, included medical documents he submitted in connection with resuming his flight training. The"
+ " announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz's battle"
+ " with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa,"
+ " whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday"
+ ' as a "swift and seamless clarification" and said it was sharing the information and documents --'
+ " including training and medical records -- with public prosecutors. Spohr traveled to the crash site"
+ " Wednesday, where recovery teams have been working for the past week to recover human remains and"
+ " plane debris scattered across a steep mountainside. He saw the crisis center set up in"
+ " Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving"
+ " families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no"
+ " visible human remains were left at the site but recovery teams would keep searching. French"
+ " President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the"
+ " victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini"
+ " said. Among those personal belongings could be more cell phones belonging to the 144 passengers and"
+ " six crew on board. Check out the latest from our correspondents . The details about Lubitz's"
+ " correspondence with the flight school during his training were among several developments as"
+ " investigators continued to delve into what caused the crash and Lubitz's possible motive for"
+ " downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical"
+ ' certificate, had passed all his examinations and "held all the licenses required." Earlier, a'
+ " spokesman for the prosecutor's office in Dusseldorf, Christoph Kumpa, said medical records reveal"
+ " Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent"
+ " psychotherapy before he got his pilot's license. Kumpa emphasized there's no evidence suggesting"
+ " Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether"
+ " Lubitz feared his medical condition would cause him to lose his pilot's license, a European"
+ ' government official briefed on the investigation told CNN on Tuesday. While flying was "a big part'
+ " of his life,\" the source said, it's only one theory being considered. Another source, a law"
+ " enforcement official briefed on the investigation, also told CNN that authorities believe the"
+ " primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly"
+ " because of his medical problems. Lubitz's girlfriend told investigators he had seen an eye doctor"
+ " and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had"
+ " psychological issues, the European government official said. But no matter what details emerge about"
+ " his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the'
+ " fact that maybe they weren't going to keep doing their job and they're upset about that and so"
+ ' they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels'
+ " entitled to also take that rage and turn it outward on 149 other people who had nothing to do with"
+ " the person's problems.\" Germanwings crash compensation: What we know . Who was the captain of"
+ " Germanwings Flight 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from"
+ " Dusseldorf, while Laura Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff,"
+ " Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.",
],
return_tensors="pt",
padding="longest",
@@ -641,7 +908,10 @@ def test_xsum_summarization_same_as_fairseq(self):
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
- EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state."
+ EXPECTED_SUMMARY = (
+ "California's largest power company has begun shutting off electricity to thousands of customers in the"
+ " state."
+ )
dct = tok.batch_encode_plus(
[PGE_ARTICLE],
max_length=1024,
@@ -679,14 +949,197 @@ def test_cnn_summarization_same_as_fairseq(self):
hf = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device)
tok = BartTokenizer.from_pretrained("facebook/bart-large")
- FRANCE_ARTICLE = ' Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noq
+ FRANCE_ARTICLE = ( # @noq
+ " Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
+ " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane."
+ ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."'
+ ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s'
+ " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French"
+ " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a"
+ " phone at the wreckage site. The two publications described the supposed video, but did not post it on"
+ " their websites. The publications said that they watched the video, which was found by a source close to"
+ " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported."
+ ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the'
+ " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the"
+ ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,'
+ " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said"
+ " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman"
+ " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the"
+ ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,'
+ ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be'
+ " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by"
+ " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so"
+ " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could"
+ ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin'
+ ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match'
+ ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something'
+ " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the"
+ ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline'
+ " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the"
+ " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the"
+ ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of'
+ ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school'
+ " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in"
+ " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent"
+ " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and"
+ " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%"
+ ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was'
+ " sharing the information and documents -- including training and medical records -- with public"
+ " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the"
+ " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the"
+ " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash"
+ " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late"
+ " Tuesday that no visible human remains were left at the site but recovery teams would keep searching."
+ " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all"
+ " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said."
+ " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew"
+ " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with"
+ " the flight school during his training were among several developments as investigators continued to"
+ " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa"
+ " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his"
+ ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in'
+ " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at"
+ " some point before his aviation career and underwent psychotherapy before he got his pilot's license."
+ " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the"
+ " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to"
+ " lose his pilot's license, a European government official briefed on the investigation told CNN on"
+ ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being'
+ " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that"
+ " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would"
+ " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had"
+ " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded"
+ " he had psychological issues, the European government official said. But no matter what details emerge"
+ " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact'
+ " that maybe they weren't going to keep doing their job and they're upset about that and so they're"
+ ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to'
+ " also take that rage and turn it outward on 149 other people who had nothing to do with the person's"
+ ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight'
+ " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura"
+ " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine"
+ " Amiel and Anna-Maja Rappard contributed to this report."
+ )
- SHORTER_ARTICLE = ' (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. 
The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
+ SHORTER_ARTICLE = (
+ " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
+ " and Faith Karimi contributed to this report."
+ )
# The below article tests that we don't add any hypotheses outside of the top n_beams
- IRAN_ARTICLE = " (CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. 
This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions."
+ IRAN_ARTICLE = (
+ " (CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran"
+ " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively"
+ " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger."
+ " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli"
+ " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a"
+ " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since"
+ " the announcement of the new framework will likely result in more heat than light. It will not be helped"
+ " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ."
+ " The most misleading assertion, despite universal rejection by experts, is that the negotiations'"
+ " objective at the outset was the total elimination of any nuclear program in Iran. That is the position"
+ " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it"
+ " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has"
+ " always been to structure an agreement or series of agreements so that Iran could not covertly develop a"
+ " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded"
+ " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by"
+ " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another"
+ " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite"
+ " sharp accusations by some in the United States and its allies, Iran denies having such a program, and"
+ " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's"
+ " continued cooperation with International Atomic Energy Agency inspections is further evidence on this"
+ " point, and we'll know even more about Iran's program in the coming months and years because of the deal."
+ " In fact, the inspections provisions that are part of this agreement are designed to protect against any"
+ " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that"
+ " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter"
+ " warning that a deal might be killed by Congress or a future president). This of course is not the case."
+ " The talks were between Iran and the five permanent members of the U.N. Security Council (United States,"
+ " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has"
+ " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement"
+ " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran"
+ " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement"
+ " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the"
+ " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased"
+ " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes"
+ " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear"
+ " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going"
+ " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such"
+ " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the"
+ ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not'
+ " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New"
+ " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement"
+ " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement"
+ " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove"
+ " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally"
+ " some insist that any agreement must address Iranian missile programs, human rights violations or support"
+ " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are"
+ " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in"
+ " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it"
+ " affects the security of our negotiating partners and allies, including Israel. Those judgments should be"
+ " fact-based, not based on questionable assertions or dubious assumptions."
+ )
- ARTICLE_SUBWAY = ' New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
+ ARTICLE_SUBWAY = (
+ " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
+ " year later, she got married again in Westchester County, but to a different man and without divorcing"
+ " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
+ ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
+ " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
+ ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
+ ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
+ " license application, according to court documents. Prosecutors said the marriages were part of an"
+ " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
+ " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
+ " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
+ " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
+ " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
+ " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
+ " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
+ " said the immigration scam involved some of her husbands, who filed for permanent residence status"
+ " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
+ " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
+ " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
+ ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
+ " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
+ " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
+ " up to four years in prison. Her next court appearance is scheduled for May 18."
+ )
dct = tok.batch_encode_plus(
[FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
@@ -943,6 +1396,7 @@ def prepare_config_and_inputs_for_common(self):
class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (BartDecoder, BartForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (BartForCausalLM,) if is_torch_available() else ()
+ fx_compatible = True
test_pruning = False
is_encoder_decoder = False
diff --git a/tests/models/bart/test_modeling_flax_bart.py b/tests/models/bart/test_modeling_flax_bart.py
index ef4f9d38525f..54a6ff4534df 100644
--- a/tests/models/bart/test_modeling_flax_bart.py
+++ b/tests/models/bart/test_modeling_flax_bart.py
@@ -420,7 +420,10 @@ def test_summarization_fast(self):
model = FlaxBartForConditionalGeneration.from_pretrained("sshleifer/distilbart-cnn-6-6")
tokenizer = BartTokenizer.from_pretrained("sshleifer/distilbart-cnn-6-6")
- input_str = "This sentence is made of three parts. Each part is important on its own. One part is about animals, the other part about planes, and the last part about housing."
+ input_str = (
+ "This sentence is made of three parts. Each part is important on its own. One part is about animals, the"
+ " other part about planes, and the last part about housing."
+ )
input_ids = tokenizer(input_str, return_tensors="np").input_ids
sequences = model.generate(input_ids, num_beams=2, max_length=20).sequences
@@ -436,14 +439,197 @@ def test_cnn_summarization_same_as_fairseq(self):
model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
- FRANCE_ARTICLE = ' Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. 
Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. 
CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noq
+ FRANCE_ARTICLE = ( # @noq
+ " Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
+ " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane."
+ ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."'
+ ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s'
+ " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French"
+ " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a"
+ " phone at the wreckage site. The two publications described the supposed video, but did not post it on"
+ " their websites. The publications said that they watched the video, which was found by a source close to"
+ " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported."
+ ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the'
+ " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the"
+ ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,'
+ " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said"
+ " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman"
+ " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the"
+ ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,'
+ ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be'
+ " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by"
+ " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so"
+ " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could"
+ ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin'
+ ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match'
+ ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something'
+ " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the"
+ ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline'
+ " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the"
+ " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the"
+ ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of'
+ ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school'
+ " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in"
+ " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent"
+ " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and"
+ " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%"
+ ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was'
+ " sharing the information and documents -- including training and medical records -- with public"
+ " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the"
+ " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the"
+ " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash"
+ " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late"
+ " Tuesday that no visible human remains were left at the site but recovery teams would keep searching."
+ " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all"
+ " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said."
+ " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew"
+ " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with"
+ " the flight school during his training were among several developments as investigators continued to"
+ " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa"
+ " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his"
+ ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in'
+ " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at"
+ " some point before his aviation career and underwent psychotherapy before he got his pilot's license."
+ " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the"
+ " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to"
+ " lose his pilot's license, a European government official briefed on the investigation told CNN on"
+ ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being'
+ " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that"
+ " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would"
+ " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had"
+ " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded"
+ " he had psychological issues, the European government official said. But no matter what details emerge"
+ " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact'
+ " that maybe they weren't going to keep doing their job and they're upset about that and so they're"
+ ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to'
+ " also take that rage and turn it outward on 149 other people who had nothing to do with the person's"
+ ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight'
+ " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura"
+ " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine"
+ " Amiel and Anna-Maja Rappard contributed to this report."
+ )
- SHORTER_ARTICLE = ' (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. 
The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
+ SHORTER_ARTICLE = (
+ " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
+ " and Faith Karimi contributed to this report."
+ )
# The below article tests that we don't add any hypotheses outside of the top n_beams
- IRAN_ARTICLE = " (CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions."
+ IRAN_ARTICLE = (
+ " (CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran"
+ " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively"
+ " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger."
+ " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli"
+ " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a"
+ " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since"
+ " the announcement of the new framework will likely result in more heat than light. It will not be helped"
+ " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ."
+ " The most misleading assertion, despite universal rejection by experts, is that the negotiations'"
+ " objective at the outset was the total elimination of any nuclear program in Iran. That is the position"
+ " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it"
+ " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has"
+ " always been to structure an agreement or series of agreements so that Iran could not covertly develop a"
+ " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded"
+ " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by"
+ " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another"
+ " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite"
+ " sharp accusations by some in the United States and its allies, Iran denies having such a program, and"
+ " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's"
+ " continued cooperation with International Atomic Energy Agency inspections is further evidence on this"
+ " point, and we'll know even more about Iran's program in the coming months and years because of the deal."
+ " In fact, the inspections provisions that are part of this agreement are designed to protect against any"
+ " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that"
+ " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter"
+ " warning that a deal might be killed by Congress or a future president). This of course is not the case."
+ " The talks were between Iran and the five permanent members of the U.N. Security Council (United States,"
+ " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has"
+ " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement"
+ " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran"
+ " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement"
+ " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the"
+ " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased"
+ " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes"
+ " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear"
+ " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going"
+ " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such"
+ " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the"
+ ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not'
+ " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New"
+ " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement"
+ " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement"
+ " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove"
+ " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally"
+ " some insist that any agreement must address Iranian missile programs, human rights violations or support"
+ " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are"
+ " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in"
+ " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it"
+ " affects the security of our negotiating partners and allies, including Israel. Those judgments should be"
+ " fact-based, not based on questionable assertions or dubious assumptions."
+ )
- ARTICLE_SUBWAY = ' New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
+ ARTICLE_SUBWAY = (
+ " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
+ " year later, she got married again in Westchester County, but to a different man and without divorcing"
+ " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
+ ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
+ " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
+ ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
+ ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
+ " license application, according to court documents. Prosecutors said the marriages were part of an"
+ " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
+ " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
+ " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
+ " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
+ " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
+ " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
+ " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
+ " said the immigration scam involved some of her husbands, who filed for permanent residence status"
+ " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
+ " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
+ " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
+ ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
+ " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
+ " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
+ " up to four years in prison. Her next court appearance is scheduled for May 18."
+ )
dct = tokenizer.batch_encode_plus(
[FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
@@ -463,10 +649,21 @@ def test_cnn_summarization_same_as_fairseq(self):
assert (hypotheses_batch[:, 1] == 0).all().item()
EXPECTED = [
- "A French prosecutor says he is not aware of any video footage from on board the plane. Two German magazines claim to have found a cell phone video showing the crash. The publications say they watched the video, which was found by a source close to the investigation. All 150 on board the Germanwings flight were killed.",
- "Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a move toward greater justice.",
- "U.S. and its negotiating partners reached a strong framework agreement with Iran. Peter Bergen: The debate that has already begun will likely result in more heat than light. Bergen: The most misleading assertion is that the negotiations' objective at the outset was the total elimination of any nuclear program.",
- "Liana Barrientos, 39, has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. She pleaded not guilty at State Supreme Court in the Bronx on Friday. If convicted, Barrientos faces up to four years in prison.",
+ "A French prosecutor says he is not aware of any video footage from on board the plane. Two German"
+ " magazines claim to have found a cell phone video showing the crash. The publications say they watched"
+ " the video, which was found by a source close to the investigation. All 150 on board the Germanwings"
+ " flight were killed.",
+ "Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court"
+ " jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the"
+ " Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a"
+ " move toward greater justice.",
+ "U.S. and its negotiating partners reached a strong framework agreement with Iran. Peter Bergen: The"
+ " debate that has already begun will likely result in more heat than light. Bergen: The most misleading"
+ " assertion is that the negotiations' objective at the outset was the total elimination of any nuclear"
+ " program.",
+ "Liana Barrientos, 39, has been married 10 times, sometimes within two weeks of each other. Prosecutors"
+ " say the marriages were part of an immigration scam. She pleaded not guilty at State Supreme Court in the"
+ " Bronx on Friday. If convicted, Barrientos faces up to four years in prison.",
]
generated_summaries = tokenizer.batch_decode(
diff --git a/tests/models/bart/test_modeling_tf_bart.py b/tests/models/bart/test_modeling_tf_bart.py
index 29c61a1e40e7..5e5c5ee592a1 100644
--- a/tests/models/bart/test_modeling_tf_bart.py
+++ b/tests/models/bart/test_modeling_tf_bart.py
@@ -18,7 +18,7 @@
import numpy as np
from transformers import BartConfig, BartTokenizer, is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers.testing_utils import require_tf, slow, tooslow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -125,8 +125,22 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)
- output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
- output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]
+ decoder_position_ids = tf.cast(tf.cumsum(next_attention_mask, axis=1, exclusive=True), dtype=tf.int32)
+ output_from_no_past = model(
+ next_input_ids, attention_mask=next_attention_mask, position_ids=decoder_position_ids
+ )
+ output_from_no_past = output_from_no_past[0]
+
+ decoder_position_ids = (
+ tf.cast(tf.cumsum(next_attn_mask, axis=1, exclusive=True), dtype=tf.int32) + past_key_values[0][0].shape[2]
+ )
+ output_from_past = model(
+ next_tokens,
+ attention_mask=next_attention_mask,
+ past_key_values=past_key_values,
+ position_ids=decoder_position_ids,
+ )
+ output_from_past = output_from_past[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
@@ -279,25 +293,11 @@ def _get_word_embedding_weight(model, embedding_layer):
models_equal = False
self.assertTrue(models_equal)
+ @tooslow
def test_saved_model_creation(self):
- # This test is too long (>30sec) and makes fail the CI
pass
-def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
- """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
- if a is None and b is None:
- return True
- try:
- if tf.debugging.assert_near(a, b, atol=atol):
- return True
- raise
- except Exception:
- if len(prefix) > 0:
- prefix = f"{prefix}: "
- raise AssertionError(f"{prefix}{a} != {b}")
-
-
def _long_tensor(tok_lst):
return tf.constant(tok_lst, dtype=tf.int32)
@@ -375,18 +375,221 @@ def test_cnn_summarization_same_as_fairseq_hard(self):
hf = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tok = self.tok
- FRANCE_ARTICLE = ' Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noqa
- EXPECTED_SUMMARY_FRANCE = 'French prosecutor says he\'s not aware of any video footage from on board the plane. German daily Bild and French Paris Match claim to have found a cell phone video of the crash. A French Gendarmerie spokesman calls the reports "completely wrong" and "unwarranted" German airline Lufthansa confirms co-pilot Andreas Lubitz had battled depression.'
+ FRANCE_ARTICLE = ( # @noqa
+ " Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
+ " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane."
+ ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."'
+ ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s'
+ " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French"
+ " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a"
+ " phone at the wreckage site. The two publications described the supposed video, but did not post it on"
+ " their websites. The publications said that they watched the video, which was found by a source close to"
+ " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported."
+ ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the'
+ " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the"
+ ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,'
+ " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said"
+ " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman"
+ " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the"
+ ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,'
+ ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be'
+ " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by"
+ " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so"
+ " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could"
+ ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin'
+ ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match'
+ ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something'
+ " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the"
+ ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline'
+ " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the"
+ " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the"
+ ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of'
+ ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school'
+ " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in"
+ " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent"
+ " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and"
+ " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%"
+ ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was'
+ " sharing the information and documents -- including training and medical records -- with public"
+ " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the"
+ " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the"
+ " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash"
+ " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late"
+ " Tuesday that no visible human remains were left at the site but recovery teams would keep searching."
+ " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all"
+ " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said."
+ " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew"
+ " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with"
+ " the flight school during his training were among several developments as investigators continued to"
+ " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa"
+ " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his"
+ ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in'
+ " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at"
+ " some point before his aviation career and underwent psychotherapy before he got his pilot's license."
+ " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the"
+ " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to"
+ " lose his pilot's license, a European government official briefed on the investigation told CNN on"
+ ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being'
+ " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that"
+ " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would"
+ " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had"
+ " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded"
+ " he had psychological issues, the European government official said. But no matter what details emerge"
+ " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact'
+ " that maybe they weren't going to keep doing their job and they're upset about that and so they're"
+ ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to'
+ " also take that rage and turn it outward on 149 other people who had nothing to do with the person's"
+ ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight'
+ " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura"
+ " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine"
+ " Amiel and Anna-Maja Rappard contributed to this report."
+ )
+ EXPECTED_SUMMARY_FRANCE = (
+ "French prosecutor says he's not aware of any video footage from on board the plane. German daily Bild"
+ " and French Paris Match claim to have found a cell phone video of the crash. A French Gendarmerie"
+ ' spokesman calls the reports "completely wrong" and "unwarranted" German airline Lufthansa confirms'
+ " co-pilot Andreas Lubitz had battled depression."
+ )
- SHORTER_ARTICLE = ' (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
- EXPECTED_SUMMARY_SHORTER = "The Palestinian Authority becomes the 123rd member of the International Criminal Court. The move gives the court jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a move toward greater justice."
+ SHORTER_ARTICLE = (
+ " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
+ " and Faith Karimi contributed to this report."
+ )
+ EXPECTED_SUMMARY_SHORTER = (
+ "The Palestinian Authority becomes the 123rd member of the International Criminal Court. The move gives"
+ " the court jurisdiction over alleged crimes in Palestinian territories. Israel and the United States"
+ " opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said"
+ " it was a move toward greater justice."
+ )
# The below article tests that we don't add any hypotheses outside of the top n_beams
- IRAN_ARTICLE = " (CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions."
- EXPECTED_SUMMARY_IRAN = "The U.S. and its negotiating partners reached a very strong framework agreement with Iran. Peter Bergen: The debate that has already begun will likely result in more heat than light. He says the agreement limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Bergen says the most important aim of a nuclear deal is preventing a nuclear Iran."
+ IRAN_ARTICLE = (
+ " (CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran"
+ " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively"
+ " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger."
+ " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli"
+ " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a"
+ " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since"
+ " the announcement of the new framework will likely result in more heat than light. It will not be helped"
+ " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ."
+ " The most misleading assertion, despite universal rejection by experts, is that the negotiations'"
+ " objective at the outset was the total elimination of any nuclear program in Iran. That is the position"
+ " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it"
+ " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has"
+ " always been to structure an agreement or series of agreements so that Iran could not covertly develop a"
+ " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded"
+ " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by"
+ " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another"
+ " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite"
+ " sharp accusations by some in the United States and its allies, Iran denies having such a program, and"
+ " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's"
+ " continued cooperation with International Atomic Energy Agency inspections is further evidence on this"
+ " point, and we'll know even more about Iran's program in the coming months and years because of the deal."
+ " In fact, the inspections provisions that are part of this agreement are designed to protect against any"
+ " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that"
+ " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter"
+ " warning that a deal might be killed by Congress or a future president). This of course is not the case."
+ " The talks were between Iran and the five permanent members of the U.N. Security Council (United States,"
+ " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has"
+ " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement"
+ " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran"
+ " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement"
+ " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the"
+ " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased"
+ " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes"
+ " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear"
+ " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going"
+ " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such"
+ " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the"
+ ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not'
+ " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New"
+ " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement"
+ " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement"
+ " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove"
+ " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally"
+ " some insist that any agreement must address Iranian missile programs, human rights violations or support"
+ " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are"
+ " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in"
+ " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it"
+ " affects the security of our negotiating partners and allies, including Israel. Those judgments should be"
+ " fact-based, not based on questionable assertions or dubious assumptions."
+ )
+ EXPECTED_SUMMARY_IRAN = (
+ "The U.S. and its negotiating partners reached a very strong framework agreement with Iran. Peter Bergen:"
+ " The debate that has already begun will likely result in more heat than light. He says the agreement"
+ " limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon."
+ " Bergen says the most important aim of a nuclear deal is preventing a nuclear Iran."
+ )
- ARTICLE_SUBWAY = ' New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
- EXPECTED_SUMMARY_SUBWAY = "Liana Barrientos has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx. She was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the subway."
+ ARTICLE_SUBWAY = (
+ " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
+ " year later, she got married again in Westchester County, but to a different man and without divorcing"
+ " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
+ ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
+ " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
+ ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
+ ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
+ " license application, according to court documents. Prosecutors said the marriages were part of an"
+ " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
+ " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
+ " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
+ " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
+ " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
+ " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
+ " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
+ " said the immigration scam involved some of her husbands, who filed for permanent residence status"
+ " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
+ " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
+ " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
+ ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
+ " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
+ " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
+ " up to four years in prison. Her next court appearance is scheduled for May 18."
+ )
+ EXPECTED_SUMMARY_SUBWAY = (
+ "Liana Barrientos has been married 10 times, sometimes within two weeks of each other. Prosecutors say the"
+ " marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in"
+ " the Bronx. She was arrested and charged with theft of service and criminal trespass for allegedly"
+ " sneaking into the subway."
+ )
dct = tok(
[FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
@@ -433,18 +636,221 @@ def xsum_1_1_model(self):
def test_xsum_1_1_generation(self):
model = self.xsum_1_1_model
assert model.model.decoder.embed_tokens._layer == model.model.shared
- ARTICLE = 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.'
- EXPECTED = " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court."
+ ARTICLE = (
+ "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes."
+ )
+ EXPECTED = (
+ " The International Criminal Court (ICC) has announced that it has been announced by the International"
+ " Criminal court."
+ )
dct = self.tok(ARTICLE, return_tensors="tf")
generated_ids = model.generate(**dct, num_beams=4)
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
assert result == EXPECTED
+ def test_xsum_1_1_xla_generation(self):
+ # same test as above, but with `no_repeat_ngram_size=0` (not compatible with XLA) and XLA comparison enabled
+ model = self.xsum_1_1_model
+ assert model.model.decoder.embed_tokens._layer == model.model.shared
+ ARTICLE = (
+ "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes."
+ )
+ EXPECTED = (
+ " The International Criminal Court (ICC) has announced that it is to be investigated by the International"
+ " Criminal Court (ICC) over allegations of war crimes."
+ )
+
+ dct = self.tok(ARTICLE, return_tensors="tf")
+ generated_ids = model.generate(**dct, num_beams=4, no_repeat_ngram_size=0)
+ result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ assert result == EXPECTED
+
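+ # descriptive comment: tf.function with jit_compile=True compiles `generate` with XLA; the decoded output is expected to match the eager result asserted above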
+ xla_generate = tf.function(model.generate, jit_compile=True)
+ generated_ids = xla_generate(**dct, num_beams=4, no_repeat_ngram_size=0)
+ result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ assert result == EXPECTED
+
def test_xsum_1_1_batch_generation(self):
batch = self.tok(
[
- 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.',
- 'The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.',
+ "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories."
+ " The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is"
+ " based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted"
+ ' its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including'
+ ' East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination'
+ " into the situation in Palestinian territories, paving the way for possible war crimes investigations"
+ " against Israelis. As members of the court, Palestinians may be subject to counter-charges as well."
+ " Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts"
+ " to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony,"
+ ' said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome'
+ ' Statute today, the world is also a step closer to ending a long era of impunity and injustice," he'
+ ' said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of'
+ ' justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was'
+ ' just the first step for the Palestinians. "As the Rome Statute today enters into force for the State'
+ " of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a"
+ ' State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she'
+ ' said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize'
+ " Palestine for joining the ICC should immediately end their pressure, and countries that support"
+ " universal acceptance of the court's treaty should speak out to welcome its membership,\" said"
+ " Balkees Jarrah, international justice counsel for the group. \"What's objectionable is the attempts"
+ " to undermine international justice, not Palestine's decision to join a treaty to which over 100"
+ ' countries around the world are members." In January, when the preliminary ICC examination was'
+ " opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was"
+ ' overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s'
+ ' decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we'
+ ' do not believe that it is eligible to join the ICC," the State Department said in a statement. It'
+ ' urged the warring sides to resolve their differences through direct negotiations. "We will continue'
+ ' to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said.'
+ " But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows'
+ " the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor"
+ ' Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality."'
+ " The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The"
+ " inquiry will include alleged war crimes committed since June. The International Criminal Court was"
+ " set up in 2002 to prosecute genocide, crimes against humanity and war crimes.",
+ "The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted"
+ " Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor"
+ ' Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A'
+ " person who has such a video needs to immediately give it to the investigators.\" Robin's comments"
+ " follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the"
+ " French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was"
+ " recovered from a phone at the wreckage site. The two publications described the supposed video, but"
+ " did not post it on their websites. The publications said that they watched the video, which was"
+ " found by a source close to the investigation. \"One can hear cries of 'My God' in several"
+ ' languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps'
+ " of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy"
+ ' shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing'
+ " scene,\" said Julian Reichelt, editor-in-chief of Bild online. An official with France's accident"
+ " investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc"
+ " Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the"
+ ' Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell'
+ ' phones have been collected at the site, he said, but that they "hadn\'t been exploited yet."'
+ " Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute"
+ " in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working"
+ " hand-in-hand with investigators. But none of the cell phones found so far have been sent to the"
+ " institute, Menichini said. Asked whether staff involved in the search could have leaked a memory"
+ ' card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett:'
+ ' Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are'
+ ' "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is'
+ " something we did not know before. ... Overall we can say many things of the investigation weren't"
+ ' revealed by the investigation at the beginning," he said. What was mental state of Germanwings'
+ " co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled"
+ " depression years before he took the controls of Germanwings Flight 9525, which he's accused of"
+ " deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school"
+ ' in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email'
+ " correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa"
+ " said, included medical documents he submitted in connection with resuming his flight training. The"
+ " announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz's battle"
+ " with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa,"
+ " whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday"
+ ' as a "swift and seamless clarification" and said it was sharing the information and documents --'
+ " including training and medical records -- with public prosecutors. Spohr traveled to the crash site"
+ " Wednesday, where recovery teams have been working for the past week to recover human remains and"
+ " plane debris scattered across a steep mountainside. He saw the crisis center set up in"
+ " Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving"
+ " families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no"
+ " visible human remains were left at the site but recovery teams would keep searching. French"
+ " President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the"
+ " victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini"
+ " said. Among those personal belongings could be more cell phones belonging to the 144 passengers and"
+ " six crew on board. Check out the latest from our correspondents . The details about Lubitz's"
+ " correspondence with the flight school during his training were among several developments as"
+ " investigators continued to delve into what caused the crash and Lubitz's possible motive for"
+ " downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical"
+ ' certificate, had passed all his examinations and "held all the licenses required." Earlier, a'
+ " spokesman for the prosecutor's office in Dusseldorf, Christoph Kumpa, said medical records reveal"
+ " Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent"
+ " psychotherapy before he got his pilot's license. Kumpa emphasized there's no evidence suggesting"
+ " Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether"
+ " Lubitz feared his medical condition would cause him to lose his pilot's license, a European"
+ ' government official briefed on the investigation told CNN on Tuesday. While flying was "a big part'
+ " of his life,\" the source said, it's only one theory being considered. Another source, a law"
+ " enforcement official briefed on the investigation, also told CNN that authorities believe the"
+ " primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly"
+ " because of his medical problems. Lubitz's girlfriend told investigators he had seen an eye doctor"
+ " and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had"
+ " psychological issues, the European government official said. But no matter what details emerge about"
+ " his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the'
+ " fact that maybe they weren't going to keep doing their job and they're upset about that and so"
+ ' they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels'
+ " entitled to also take that rage and turn it outward on 149 other people who had nothing to do with"
+ " the person's problems.\" Germanwings crash compensation: What we know . Who was the captain of"
+ " Germanwings Flight 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from"
+ " Dusseldorf, while Laura Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff,"
+ " Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.",
],
return_tensors="tf",
padding="longest",
@@ -454,18 +860,128 @@ def test_xsum_1_1_batch_generation(self):
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)
assert (
result[0]
- == " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court."
+ == " The International Criminal Court (ICC) has announced that it has been announced by the International"
+ " Criminal court."
)
assert (
result[1]
- == " An investigation into the crash that killed at least 10 people in the French capital has been released by the French police investigating the crash."
+ == " An investigation into the crash that killed at least 10 people in the French capital has been"
+ " released by the French police investigating the crash."
)
def test_encoder_equiv(self):
batch = self.tok(
[
- 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.',
- 'The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.',
+ "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories."
+ " The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is"
+ " based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted"
+ ' its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including'
+ ' East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination'
+ " into the situation in Palestinian territories, paving the way for possible war crimes investigations"
+ " against Israelis. As members of the court, Palestinians may be subject to counter-charges as well."
+ " Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts"
+ " to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony,"
+ ' said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome'
+ ' Statute today, the world is also a step closer to ending a long era of impunity and injustice," he'
+ ' said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of'
+ ' justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was'
+ ' just the first step for the Palestinians. "As the Rome Statute today enters into force for the State'
+ " of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a"
+ ' State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she'
+ ' said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize'
+ " Palestine for joining the ICC should immediately end their pressure, and countries that support"
+ " universal acceptance of the court's treaty should speak out to welcome its membership,\" said"
+ " Balkees Jarrah, international justice counsel for the group. \"What's objectionable is the attempts"
+ " to undermine international justice, not Palestine's decision to join a treaty to which over 100"
+ ' countries around the world are members." In January, when the preliminary ICC examination was'
+ " opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was"
+ ' overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s'
+ ' decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we'
+ ' do not believe that it is eligible to join the ICC," the State Department said in a statement. It'
+ ' urged the warring sides to resolve their differences through direct negotiations. "We will continue'
+ ' to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said.'
+ " But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows'
+ " the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor"
+ ' Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality."'
+ " The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The"
+ " inquiry will include alleged war crimes committed since June. The International Criminal Court was"
+ " set up in 2002 to prosecute genocide, crimes against humanity and war crimes.",
+ "The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted"
+ " Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor"
+ ' Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A'
+ " person who has such a video needs to immediately give it to the investigators.\" Robin's comments"
+ " follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the"
+ " French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was"
+ " recovered from a phone at the wreckage site. The two publications described the supposed video, but"
+ " did not post it on their websites. The publications said that they watched the video, which was"
+ " found by a source close to the investigation. \"One can hear cries of 'My God' in several"
+ ' languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps'
+ " of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy"
+ ' shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing'
+ " scene,\" said Julian Reichelt, editor-in-chief of Bild online. An official with France's accident"
+ " investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc"
+ " Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the"
+ ' Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell'
+ ' phones have been collected at the site, he said, but that they "hadn\'t been exploited yet."'
+ " Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute"
+ " in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working"
+ " hand-in-hand with investigators. But none of the cell phones found so far have been sent to the"
+ " institute, Menichini said. Asked whether staff involved in the search could have leaked a memory"
+ ' card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett:'
+ ' Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are'
+ ' "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is'
+ " something we did not know before. ... Overall we can say many things of the investigation weren't"
+ ' revealed by the investigation at the beginning," he said. What was mental state of Germanwings'
+ " co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled"
+ " depression years before he took the controls of Germanwings Flight 9525, which he's accused of"
+ " deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school"
+ ' in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email'
+ " correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa"
+ " said, included medical documents he submitted in connection with resuming his flight training. The"
+ " announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz's battle"
+ " with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa,"
+ " whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday"
+ ' as a "swift and seamless clarification" and said it was sharing the information and documents --'
+ " including training and medical records -- with public prosecutors. Spohr traveled to the crash site"
+ " Wednesday, where recovery teams have been working for the past week to recover human remains and"
+ " plane debris scattered across a steep mountainside. He saw the crisis center set up in"
+ " Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving"
+ " families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no"
+ " visible human remains were left at the site but recovery teams would keep searching. French"
+ " President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the"
+ " victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini"
+ " said. Among those personal belongings could be more cell phones belonging to the 144 passengers and"
+ " six crew on board. Check out the latest from our correspondents . The details about Lubitz's"
+ " correspondence with the flight school during his training were among several developments as"
+ " investigators continued to delve into what caused the crash and Lubitz's possible motive for"
+ " downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical"
+ ' certificate, had passed all his examinations and "held all the licenses required." Earlier, a'
+ " spokesman for the prosecutor's office in Dusseldorf, Christoph Kumpa, said medical records reveal"
+ " Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent"
+ " psychotherapy before he got his pilot's license. Kumpa emphasized there's no evidence suggesting"
+ " Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether"
+ " Lubitz feared his medical condition would cause him to lose his pilot's license, a European"
+ ' government official briefed on the investigation told CNN on Tuesday. While flying was "a big part'
+ " of his life,\" the source said, it's only one theory being considered. Another source, a law"
+ " enforcement official briefed on the investigation, also told CNN that authorities believe the"
+ " primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly"
+ " because of his medical problems. Lubitz's girlfriend told investigators he had seen an eye doctor"
+ " and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had"
+ " psychological issues, the European government official said. But no matter what details emerge about"
+ " his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the'
+ " fact that maybe they weren't going to keep doing their job and they're upset about that and so"
+ ' they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels'
+ " entitled to also take that rage and turn it outward on 149 other people who had nothing to do with"
+ " the person's problems.\" Germanwings crash compensation: What we know . Who was the captain of"
+ " Germanwings Flight 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from"
+ " Dusseldorf, while Laura Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff,"
+ " Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.",
],
return_tensors="tf",
padding="longest",
diff --git a/tests/models/bart/test_tokenization_bart.py b/tests/models/bart/test_tokenization_bart.py
index b8e216e69ba2..24ea6e1e5cd9 100644
--- a/tests/models/bart/test_tokenization_bart.py
+++ b/tests/models/bart/test_tokenization_bart.py
@@ -112,14 +112,13 @@ def test_prepare_batch_empty_target_text(self):
self.assertNotIn("decoder_attention_mask", batch)
@require_torch
- def test_as_target_tokenizer_target_length(self):
+ def test_tokenizer_as_target_length(self):
tgt_text = [
"Summary of the text.",
"Another summary.",
]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
- with tokenizer.as_target_tokenizer():
- targets = tokenizer(tgt_text, max_length=32, padding="max_length", return_tensors="pt")
+ targets = tokenizer(text_target=tgt_text, max_length=32, padding="max_length", return_tensors="pt")
self.assertEqual(32, targets["input_ids"].shape[1])
@require_torch
@@ -140,8 +139,7 @@ def test_special_tokens(self):
]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
inputs = tokenizer(src_text, return_tensors="pt")
- with tokenizer.as_target_tokenizer():
- targets = tokenizer(tgt_text, return_tensors="pt")
+ targets = tokenizer(text_target=tgt_text, return_tensors="pt")
input_ids = inputs["input_ids"]
labels = targets["input_ids"]
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py
index 8c9202c34b10..7d2d75d2881b 100644
--- a/tests/models/beit/test_modeling_beit.py
+++ b/tests/models/beit/test_modeling_beit.py
@@ -23,7 +23,7 @@
from transformers import BeitConfig
from transformers.models.auto import get_values
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -153,6 +153,16 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+ # test greyscale images
+ config.num_channels = 1
+ model = BeitForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
def create_and_check_for_semantic_segmentation(self, config, pixel_values, labels, pixel_labels):
config.num_labels = self.num_labels
model = BeitForSemanticSegmentation(config)
@@ -202,6 +212,11 @@ def test_config(self):
def test_inputs_embeds(self):
pass
+ @require_torch_multi_gpu
+ @unittest.skip(reason="BEiT has some layers using `add_module` which doesn't work well with `nn.DataParallel`")
+ def test_multi_gpu_data_parallel_forward(self):
+ pass
+
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/beit/test_modeling_flax_beit.py b/tests/models/beit/test_modeling_flax_beit.py
index 50996dedc7af..b37dd5bf36b4 100644
--- a/tests/models/beit/test_modeling_flax_beit.py
+++ b/tests/models/beit/test_modeling_flax_beit.py
@@ -105,7 +105,6 @@ def prepare_config_and_inputs(self):
return config, pixel_values, labels
def create_and_check_model(self, config, pixel_values, labels):
-
model = FlaxBeitModel(config=config)
result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
@@ -121,6 +120,13 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+ # test greyscale images
+ config.num_channels = 1
+ model = FlaxBeitForImageClassification(config)
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
diff --git a/tests/models/bert/test_tokenization_bert.py b/tests/models/bert/test_tokenization_bert.py
index fcb69914b94d..dfbcd266c499 100644
--- a/tests/models/bert/test_tokenization_bert.py
+++ b/tests/models/bert/test_tokenization_bert.py
@@ -187,7 +187,7 @@ def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
vocab = {}
- for (i, token) in enumerate(vocab_tokens):
+ for i, token in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
diff --git a/tests/models/bert/test_tokenization_bert_tf.py b/tests/models/bert/test_tokenization_bert_tf.py
new file mode 100644
index 000000000000..4ace9c936093
--- /dev/null
+++ b/tests/models/bert/test_tokenization_bert_tf.py
@@ -0,0 +1,100 @@
+import unittest
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+from transformers import AutoConfig, TFAutoModel, is_tensorflow_text_available, is_tf_available
+from transformers.models.bert.tokenization_bert import BertTokenizer
+from transformers.testing_utils import require_tensorflow_text, slow
+
+
+if is_tensorflow_text_available():
+ from transformers.models.bert import TFBertTokenizer
+
+if is_tf_available():
+ import tensorflow as tf
+
+
+TOKENIZER_CHECKPOINTS = ["bert-base-uncased", "bert-base-cased"]
+TINY_MODEL_CHECKPOINT = "hf-internal-testing/tiny-bert-tf-only"
+
+if is_tf_available():
+
+ class ModelToSave(tf.keras.Model):
+ def __init__(self, tokenizer):
+ super().__init__()
+ self.tokenizer = tokenizer
+ config = AutoConfig.from_pretrained(TINY_MODEL_CHECKPOINT)
+ self.bert = TFAutoModel.from_config(config)
+
+ def call(self, inputs):
+ tokenized = self.tokenizer(inputs)
+ out = self.bert(**tokenized)
+ return out["pooler_output"]
+
+
+@require_tensorflow_text
+class BertTokenizationTest(unittest.TestCase):
+ # The TF tokenizers are usually going to be used as pretrained tokenizers from existing model checkpoints,
+ # so that's what we focus on here.
+
+ def setUp(self):
+ super().setUp()
+
+ self.tokenizers = [BertTokenizer.from_pretrained(checkpoint) for checkpoint in TOKENIZER_CHECKPOINTS]
+ self.tf_tokenizers = [TFBertTokenizer.from_pretrained(checkpoint) for checkpoint in TOKENIZER_CHECKPOINTS]
+ self.test_sentences = [
+ "This is a straightforward English test sentence.",
+ "This one has some weird characters\rto\nsee\r\nif those\u00E9break things.",
+ "Now we're going to add some Chinese: äø äŗ äø äøäŗäø",
+ "And some much more rare Chinese: é½ å é½å ",
+ "Je vais aussi Ʃcrire en franƧais pour tester les accents",
+ "Classical Irish also has some unusual characters, so in they go: GaelaÄ, ź¼",
+ ]
+ self.paired_sentences = list(zip(self.test_sentences, self.test_sentences[::-1]))
+
+ def test_output_equivalence(self):
+ for tokenizer, tf_tokenizer in zip(self.tokenizers, self.tf_tokenizers):
+ for test_inputs in (self.test_sentences, self.paired_sentences):
+ python_outputs = tokenizer(test_inputs, return_tensors="tf", padding="longest")
+ tf_outputs = tf_tokenizer(test_inputs)
+
+ for key in python_outputs.keys():
+ self.assertTrue(tf.reduce_all(python_outputs[key].shape == tf_outputs[key].shape))
+ self.assertTrue(tf.reduce_all(tf.cast(python_outputs[key], tf.int64) == tf_outputs[key]))
+
+ @slow
+ def test_different_pairing_styles(self):
+ for tf_tokenizer in self.tf_tokenizers:
+ merged_outputs = tf_tokenizer(self.paired_sentences)
+ separated_outputs = tf_tokenizer(
+ text=[sentence[0] for sentence in self.paired_sentences],
+ text_pair=[sentence[1] for sentence in self.paired_sentences],
+ )
+ for key in merged_outputs.keys():
+ self.assertTrue(tf.reduce_all(tf.cast(merged_outputs[key], tf.int64) == separated_outputs[key]))
+
+ @slow
+ def test_graph_mode(self):
+ for tf_tokenizer in self.tf_tokenizers:
+ compiled_tokenizer = tf.function(tf_tokenizer)
+ for test_inputs in (self.test_sentences, self.paired_sentences):
+ test_inputs = tf.constant(test_inputs)
+ compiled_outputs = compiled_tokenizer(test_inputs)
+ eager_outputs = tf_tokenizer(test_inputs)
+
+ for key in eager_outputs.keys():
+ self.assertTrue(tf.reduce_all(eager_outputs[key] == compiled_outputs[key]))
+
+ @slow
+ def test_saved_model(self):
+ for tf_tokenizer in self.tf_tokenizers:
+ model = ModelToSave(tokenizer=tf_tokenizer)
+ test_inputs = tf.convert_to_tensor(self.test_sentences)
+ out = model(test_inputs) # Build model with some sample inputs
+ with TemporaryDirectory() as tempdir:
+ save_path = Path(tempdir) / "saved.model"
+ model.save(save_path)
+ loaded_model = tf.keras.models.load_model(save_path)
+ loaded_output = loaded_model(test_inputs)
+ # We may see small differences because the loaded model is compiled, so we need an epsilon for the test
+ self.assertLessEqual(tf.reduce_max(tf.abs(out - loaded_output)), 1e-5)
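As a usage note (a minimal sketch, not part of the test file; the path is hypothetical): because the tokenizer lives inside the graph, the exported SavedModel can be called on raw strings with no Python-side preprocessing.

import tensorflow as tf

# Hypothetical path for illustration; ModelToSave above maps raw strings to pooler outputs.
loaded = tf.keras.models.load_model("path/to/saved.model")
pooled = loaded(tf.constant(["This is a straightforward English test sentence."]))
print(pooled.shape)  # (1, hidden_size) -- tokenization happened inside the graph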
diff --git a/tests/models/bert_generation/test_tokenization_bert_generation.py b/tests/models/bert_generation/test_tokenization_bert_generation.py
index 155f383a4600..581f249db050 100644
--- a/tests/models/bert_generation/test_tokenization_bert_generation.py
+++ b/tests/models/bert_generation/test_tokenization_bert_generation.py
@@ -144,7 +144,10 @@ def test_tokenization_base_easy_symbols(self):
@slow
def test_tokenization_base_hard_symbols(self):
- symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to , such as saoneuhaoesuth'
+ symbols = (
+ 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
+ " add words that should not exsist and be tokenized to , such as saoneuhaoesuth"
+ )
original_tokenizer_encodings = [
871,
419,
diff --git a/tests/models/bert_japanese/test_tokenization_bert_japanese.py b/tests/models/bert_japanese/test_tokenization_bert_japanese.py
index 59605bac1412..86b3f16f101e 100644
--- a/tests/models/bert_japanese/test_tokenization_bert_japanese.py
+++ b/tests/models/bert_japanese/test_tokenization_bert_japanese.py
@@ -176,7 +176,7 @@ def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "ććć«ć”ćÆ", "ćć", "ć«ć”ćÆ", "ć°ććÆ", "##ćć", "##ć«ć”ćÆ", "##ć°ććÆ"]
vocab = {}
- for (i, token) in enumerate(vocab_tokens):
+ for i, token in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
@@ -249,7 +249,7 @@ def test_character_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "ć", "ć", "ć«", "ć”", "ćÆ", "ć°", "äø", "ē", "ć", "ć"]
vocab = {}
- for (i, token) in enumerate(vocab_tokens):
+ for i, token in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = CharacterTokenizer(vocab=vocab, unk_token="[UNK]")
@@ -288,7 +288,8 @@ def test_tokenizer_mismatch_warning(self):
BertTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
self.assertTrue(
cm.records[0].message.startswith(
- "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
+ "The tokenizer class you load from this checkpoint is not the same type as the class this function"
+ " is called from."
)
)
EXAMPLE_BERT_ID = "bert-base-cased"
@@ -296,6 +297,7 @@ def test_tokenizer_mismatch_warning(self):
BertJapaneseTokenizer.from_pretrained(EXAMPLE_BERT_ID)
self.assertTrue(
cm.records[0].message.startswith(
- "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
+ "The tokenizer class you load from this checkpoint is not the same type as the class this function"
+ " is called from."
)
)
diff --git a/tests/models/big_bird/test_modeling_big_bird.py b/tests/models/big_bird/test_modeling_big_bird.py
index 90e6bbb90e17..ec59f8f93d6e 100644
--- a/tests/models/big_bird/test_modeling_big_bird.py
+++ b/tests/models/big_bird/test_modeling_big_bird.py
@@ -597,13 +597,13 @@ def test_for_change_to_full_attn(self):
self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs)
# overwrite from common in order to skip the check on `attentions`
- def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
+ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
# an effort was done to return `attention_probs` (yet to be verified).
- if type(names) == str and names.startswith("attentions"):
+ if name.startswith("outputs.attentions"):
return
else:
- super().check_outputs(fx_outputs, pt_outputs, model_class, names)
+ super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
@require_torch
@@ -799,7 +799,16 @@ def test_tokenizer_inference(self):
model.to(torch_device)
text = [
- "Transformer-based models are unable to process long sequences due to their self-attention operation, which scales quadratically with the sequence length. To address this limitation, we introduce the Longformer with an attention mechanism that scales linearly with sequence length, making it easy to process documents of thousands of tokens or longer. Longformerās attention mechanism is a drop-in replacement for the standard self-attention and combines a local windowed attention with a task motivated global attention. Following prior work on long-sequence transformers, we evaluate Longformer on character-level language modeling and achieve state-of-the-art results on text8 and enwik8. In contrast to most prior work, we also pretrain Longformer and finetune it on a variety of downstream tasks. Our pretrained Longformer consistently outperforms RoBERTa on long document tasks and sets new state-of-the-art results on WikiHop and TriviaQA."
+ "Transformer-based models are unable to process long sequences due to their self-attention operation,"
+ " which scales quadratically with the sequence length. To address this limitation, we introduce the"
+ " Longformer with an attention mechanism that scales linearly with sequence length, making it easy to"
+ " process documents of thousands of tokens or longer. Longformerās attention mechanism is a drop-in"
+ " replacement for the standard self-attention and combines a local windowed attention with a task"
+ " motivated global attention. Following prior work on long-sequence transformers, we evaluate Longformer"
+ " on character-level language modeling and achieve state-of-the-art results on text8 and enwik8. In"
+ " contrast to most prior work, we also pretrain Longformer and finetune it on a variety of downstream"
+ " tasks. Our pretrained Longformer consistently outperforms RoBERTa on long document tasks and sets new"
+ " state-of-the-art results on WikiHop and TriviaQA."
]
inputs = tokenizer(text)
@@ -837,7 +846,18 @@ def test_inference_question_answering(self):
)
model.to(torch_device)
- context = "The BigBird model was proposed in Big Bird: Transformers for Longer Sequences by Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon, Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird, is a sparse-attention based transformer which extends Transformer based models, such as BERT to much longer sequences. In addition to sparse attention, BigBird also applies global attention as well as random attention to the input sequence. Theoretically, it has been shown that applying sparse, global, and random attention approximates full attention, while being computationally much more efficient for longer sequences. As a consequence of the capability to handle longer context, BigBird has shown improved performance on various long document NLP tasks, such as question answering and summarization, compared to BERT or RoBERTa."
+ context = (
+ "The BigBird model was proposed in Big Bird: Transformers for Longer Sequences by Zaheer, Manzil and"
+ " Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon, Santiago"
+ " and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird, is a"
+ " sparse-attention based transformer which extends Transformer based models, such as BERT to much longer"
+ " sequences. In addition to sparse attention, BigBird also applies global attention as well as random"
+ " attention to the input sequence. Theoretically, it has been shown that applying sparse, global, and"
+ " random attention approximates full attention, while being computationally much more efficient for longer"
+ " sequences. As a consequence of the capability to handle longer context, BigBird has shown improved"
+ " performance on various long document NLP tasks, such as question answering and summarization, compared"
+ " to BERT or RoBERTa."
+ )
question = [
"Which is better for longer sequences- BigBird or BERT?",
diff --git a/tests/models/big_bird/test_modeling_flax_big_bird.py b/tests/models/big_bird/test_modeling_flax_big_bird.py
index 5c5452441e0b..7c4c7267216a 100644
--- a/tests/models/big_bird/test_modeling_flax_big_bird.py
+++ b/tests/models/big_bird/test_modeling_flax_big_bird.py
@@ -40,7 +40,7 @@ class FlaxBigBirdModelTester(unittest.TestCase):
def __init__(
self,
parent,
- batch_size=13,
+ batch_size=2,
seq_length=56,
is_training=True,
use_attention_mask=True,
@@ -48,9 +48,9 @@ def __init__(
use_labels=True,
vocab_size=99,
hidden_size=32,
- num_hidden_layers=5,
- num_attention_heads=4,
- intermediate_size=37,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ intermediate_size=7,
hidden_act="gelu_new",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
@@ -62,7 +62,7 @@ def __init__(
attention_type="block_sparse",
use_bias=True,
rescale_embeddings=False,
- block_size=4,
+ block_size=2,
num_random_blocks=3,
):
self.parent = parent
@@ -156,10 +156,30 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
def setUp(self):
self.model_tester = FlaxBigBirdModelTester(self)
+ @slow
+ # copied from `test_modeling_flax_common` because it takes much longer than other models
+ def test_from_pretrained_save_pretrained(self):
+ super().test_from_pretrained_save_pretrained()
+
+ @slow
+ # copied from `test_modeling_flax_common` because it takes much longer than other models
+ def test_from_pretrained_with_no_automatic_init(self):
+ super().test_from_pretrained_with_no_automatic_init()
+
+ @slow
+ # copied from `test_modeling_flax_common` because it takes much longer than other models
+ def test_no_automatic_init(self):
+ super().test_no_automatic_init()
+
+ @slow
+ # copied from `test_modeling_flax_common` because it takes much longer than other models
+ def test_hidden_states_output(self):
+ super().test_hidden_states_output()
+
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
- model = model_class_name.from_pretrained("google/bigbird-roberta-base", from_pt=True)
+ model = model_class_name.from_pretrained("google/bigbird-roberta-base")
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
@@ -194,10 +214,10 @@ def model_jitted(input_ids, attention_mask=None, **kwargs):
self.assertEqual(jitted_output.shape, output.shape)
# overwrite from common in order to skip the check on `attentions`
- def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
+ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
# an effort was done to return `attention_probs` (yet to be verified).
- if type(names) == str and names.startswith("attentions"):
+ if name.startswith("outputs.attentions"):
return
else:
- super().check_outputs(fx_outputs, pt_outputs, model_class, names)
+ super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
diff --git a/tests/models/big_bird/test_tokenization_big_bird.py b/tests/models/big_bird/test_tokenization_big_bird.py
index 29c28d5877d2..ff6545100825 100644
--- a/tests/models/big_bird/test_tokenization_big_bird.py
+++ b/tests/models/big_bird/test_tokenization_big_bird.py
@@ -168,7 +168,10 @@ def test_tokenization_base_easy_symbols(self):
@slow
def test_tokenization_base_hard_symbols(self):
- symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to , such as saoneuhaoesuth'
+ symbols = (
+ 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
+ " add words that should not exsist and be tokenized to , such as saoneuhaoesuth"
+ )
# fmt: off
original_tokenizer_encodings = [65, 871, 419, 358, 946, 991, 2521, 452, 358, 1357, 387, 7751, 3536, 112, 985, 456, 126, 865, 938, 5400, 5734, 458, 1368, 467, 786, 2462, 5246, 1159, 633, 865, 4519, 457, 582, 852, 2557, 427, 916, 508, 405, 34324, 497, 391, 408, 11342, 1244, 385, 100, 938, 985, 456, 574, 362, 12597, 3200, 3129, 1172, 66] # noqa: E231
# fmt: on
diff --git a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py
index 31f109fbcf61..d4e7e8f4ae42 100644
--- a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py
+++ b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py
@@ -538,9 +538,26 @@ def test_seq_to_seq_generation(self):
hypotheses_batch = model.generate(**inputs)
- EXPECTED_LEP = "motivated by some recent studies on the light cp - odd higgs boson @xmath0 in non - minimal supersymmetric models, we investigate the rare @xmath1-decays @xmath2 ( @xmath3 ) in the two higgs doublet model ( 2hdm ), the nearly minimal supersymmetric standard model ( nmssm ), the next - to - minimal supersymmetric standard model ( nmssm ) and the minimal supersymmetric standard model ( mssm ). we find that the branching ratios of @xmath4 can reach @xmath5 in 2hdm, @xmath6 in nmssm and @xmath7 in mssm, which are at the level of @xmath8 in 2hdm, @xmath9 in nmssm and @xmath10 in mssm, respectively. these rates can be significantly enhanced in new physics models which lie within the expected sensitivity of the gigaz option of the international linear collider ( ilc ). = # 1,nucl. phys. b * # 1"
+ EXPECTED_LEP = (
+ "motivated by some recent studies on the light cp - odd higgs boson @xmath0 in non - minimal"
+ " supersymmetric models, we investigate the rare @xmath1-decays @xmath2 ( @xmath3 ) in the two higgs"
+ " doublet model ( 2hdm ), the nearly minimal supersymmetric standard model ( nmssm ), the next - to -"
+ " minimal supersymmetric standard model ( nmssm ) and the minimal supersymmetric standard model ( mssm"
+ " ). we find that the branching ratios of @xmath4 can reach @xmath5 in 2hdm, @xmath6 in nmssm and"
+ " @xmath7 in mssm, which are at the level of @xmath8 in 2hdm, @xmath9 in nmssm and @xmath10 in mssm,"
+ " respectively. these rates can be significantly enhanced in new physics models which lie within the"
+ " expected sensitivity of the gigaz option of the international linear collider ( ilc ). = # 1,nucl."
+ " phys. b * # 1"
+ )
- EXPECTED_MAGNET = "a positive, nonsaturating and dominantly linear magnetoresistance can appear within quite wide magnetic - field range in the surface state of a topological insulator having a positive and finite effective g - factor. this linear magnetoresistance shows up in the system of high carrier concentration and low mobility when electrons are in extended states and spread over many smeared landau levels, and persists up to room temperature, providing a possible mechanism for the recently observed linear magnetoresistance in topological insulator bi@xmath0se@xmath1 nanoribbons."
+ EXPECTED_MAGNET = (
+ "a positive, nonsaturating and dominantly linear magnetoresistance can appear within quite wide magnetic -"
+ " field range in the surface state of a topological insulator having a positive and finite effective g -"
+ " factor. this linear magnetoresistance shows up in the system of high carrier concentration and low"
+ " mobility when electrons are in extended states and spread over many smeared landau levels, and persists"
+ " up to room temperature, providing a possible mechanism for the recently observed linear"
+ " magnetoresistance in topological insulator bi@xmath0se@xmath1 nanoribbons."
+ )
generated = tokenizer.batch_decode(
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
diff --git a/tests/models/blenderbot/test_modeling_blenderbot.py b/tests/models/blenderbot/test_modeling_blenderbot.py
index 6bf71384671c..9b10e7690c1c 100644
--- a/tests/models/blenderbot/test_modeling_blenderbot.py
+++ b/tests/models/blenderbot/test_modeling_blenderbot.py
@@ -107,6 +107,12 @@ def __init__(
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
+ # forcing a certain token to be generated sets all other tokens to -inf;
+ # if, however, the token to be generated is already at -inf, this can lead to
+ # `nan` values and thus break generation
+ self.forced_bos_token_id = None
+ self.forced_eos_token_id = None
+
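For illustration of the comment above (a minimal sketch, not part of the diff): once the forced token's logit is itself -inf, every logit in the row ends up at -inf and the softmax produces NaNs, which is what breaks generation.

import torch

logits = torch.tensor([[0.3, float("-inf"), 1.2]])
forced_token_id = 1  # hypothetical id; assume it has already been masked to -inf

# forcing a token: every other logit is set to -inf, the forced one keeps its value
forced = torch.full_like(logits, float("-inf"))
forced[:, forced_token_id] = logits[:, forced_token_id]

print(torch.softmax(forced, dim=-1))  # tensor([[nan, nan, nan]]) -- generation breaks from here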
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
3,
@@ -135,11 +141,14 @@ def get_config(self):
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
+ forced_bos_token_id=self.forced_bos_token_id,
+ forced_eos_token_id=self.forced_eos_token_id,
)
def get_pipeline_config(self):
config = self.get_config()
config.max_position_embeddings = 100
+ config.vocab_size = 300
return config
def prepare_config_and_inputs_for_common(self):
@@ -218,6 +227,7 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
all_model_classes = (BlenderbotModel, BlenderbotForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
@@ -304,7 +314,10 @@ def test_generation_from_short_input_same_as_parlai_3B(self):
generated_txt = self.tokenizer.batch_decode(generated_utterances, **TOK_DECODE_KW)
assert generated_txt[0].strip() == tgt_text
- src_text = "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
+ src_text = (
+ "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel"
+ " like i'm going to throw up.\nand why is that?"
+ )
model_inputs = self.tokenizer([src_text], return_tensors="pt").to(torch_device)
diff --git a/tests/models/blenderbot/test_modeling_tf_blenderbot.py b/tests/models/blenderbot/test_modeling_tf_blenderbot.py
index a8ca54558f06..7b974cbe326a 100644
--- a/tests/models/blenderbot/test_modeling_tf_blenderbot.py
+++ b/tests/models/blenderbot/test_modeling_tf_blenderbot.py
@@ -17,7 +17,7 @@
import unittest
from transformers import BlenderbotConfig, BlenderbotTokenizer, is_tf_available
-from transformers.testing_utils import require_tf, require_tokenizers, slow
+from transformers.testing_utils import require_tf, require_tokenizers, slow, tooslow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -213,8 +213,8 @@ def test_model_common_attributes(self):
name = model.get_bias()
assert name is None
+ @tooslow
def test_saved_model_creation(self):
- # This test is too long (>30sec) and makes fail the CI
pass
def test_resize_token_embeddings(self):
diff --git a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py
index b046fa97d9e9..f049fe3769a1 100644
--- a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py
+++ b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py
@@ -107,6 +107,12 @@ def __init__(
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
+ # forcing a certain token to be generated sets all other tokens to -inf;
+ # if, however, the token to be generated is already at -inf, this can lead to
+ # `nan` values and thus break generation
+ self.forced_bos_token_id = None
+ self.forced_eos_token_id = None
+
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
3,
@@ -135,6 +141,8 @@ def get_config(self):
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
+ forced_bos_token_id=self.forced_bos_token_id,
+ forced_eos_token_id=self.forced_eos_token_id,
)
def prepare_config_and_inputs_for_common(self):
@@ -213,6 +221,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
all_model_classes = (BlenderbotSmallModel, BlenderbotSmallForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
@@ -290,8 +299,8 @@ def tokenizer(self):
def test_90_generation_from_long_input(self):
src_text = [
- "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like\
- i'm going to throw up.\nand why is that?"
+ "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel"
+ " like i'm going to throw up.\nand why is that?"
]
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
diff --git a/tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py
index a830e6c0b6c8..0b8d6132a20a 100644
--- a/tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py
+++ b/tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py
@@ -17,7 +17,7 @@
import unittest
from transformers import BlenderbotSmallConfig, BlenderbotSmallTokenizer, is_tf_available
-from transformers.testing_utils import require_tf, require_tokenizers, slow
+from transformers.testing_utils import require_tf, require_tokenizers, slow, tooslow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -278,8 +278,8 @@ def _get_word_embedding_weight(model, embedding_layer):
models_equal = False
self.assertTrue(models_equal)
+ @tooslow
def test_saved_model_creation(self):
- # This test is too long (>30sec) and makes fail the CI
pass
@@ -305,7 +305,8 @@ def _long_tensor(tok_lst):
@require_tf
class TFBlenderbot90MIntegrationTests(unittest.TestCase):
src_text = [
- "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
+ "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like "
+ " i'm going to throw up.\nand why is that?"
]
model_name = "facebook/blenderbot_small-90M"
diff --git a/tests/models/bloom/__init__.py b/tests/models/bloom/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py
new file mode 100644
index 000000000000..4570cb767326
--- /dev/null
+++ b/tests/models/bloom/test_modeling_bloom.py
@@ -0,0 +1,778 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+import unittest
+
+from transformers import BloomConfig, is_torch_available
+from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
+
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
+ BloomForCausalLM,
+ BloomForSequenceClassification,
+ BloomForTokenClassification,
+ BloomModel,
+ BloomTokenizerFast,
+ )
+
+
+@require_torch
+class BloomModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=14,
+ seq_length=7,
+ is_training=True,
+ use_token_type_ids=False,
+ use_input_mask=True,
+ use_labels=True,
+ use_mc_token_ids=True,
+ vocab_size=99,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ num_labels=3,
+ num_choices=4,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_token_type_ids = use_token_type_ids
+ self.use_input_mask = use_input_mask
+ self.use_labels = use_labels
+ self.use_mc_token_ids = use_mc_token_ids
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.scope = None
+ self.bos_token_id = vocab_size - 1
+ self.eos_token_id = vocab_size - 1
+ self.pad_token_id = vocab_size - 1
+
+ def get_large_model_config(self):
+ return BloomConfig.from_pretrained("bigscience/bloom")
+
+ def prepare_config_and_inputs(self, gradient_checkpointing=False):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ sequence_labels = None
+ if self.use_labels:
+ sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+
+ config = self.get_config(gradient_checkpointing=gradient_checkpointing)
+
+ return (config, input_ids, input_mask, sequence_labels)
+
+ def get_config(self, gradient_checkpointing=False, slow_but_exact=True):
+ return BloomConfig(
+ vocab_size=self.vocab_size,
+ seq_length=self.seq_length,
+ hidden_size=self.hidden_size,
+ n_layer=self.num_hidden_layers,
+ n_head=self.num_attention_heads,
+ resid_pdrop=self.hidden_dropout_prob,
+ attn_pdrop=self.attention_probs_dropout_prob,
+ n_positions=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ initializer_range=self.initializer_range,
+ use_cache=True,
+ bos_token_id=self.bos_token_id,
+ eos_token_id=self.eos_token_id,
+ pad_token_id=self.pad_token_id,
+ num_labels=self.num_labels,
+ gradient_checkpointing=gradient_checkpointing,
+ slow_but_exact=slow_but_exact,
+ dtype="float32",
+ )
+
+ def create_and_check_bloom_model(self, config, input_ids, input_mask, *args):
+ model = BloomModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(input_ids)
+
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(len(result.past_key_values), config.n_layer)
+
+ def create_and_check_bloom_model_past(self, config, input_ids, input_mask, *args):
+ model = BloomModel(config=config)
+
+ model.to(torch_device)
+ model.eval()
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=torch.ones_like(input_ids), use_cache=True)
+ outputs_use_cache_conf = model(input_ids, attention_mask=torch.ones_like(input_ids))
+ outputs_no_past = model(input_ids, use_cache=False, attention_mask=torch.ones_like(input_ids))
+
+ self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+ self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+ past = outputs["past_key_values"]
+
+ # create hypothetical next token and extend to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # append to next input_ids
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+
+ output_from_no_past = model(next_input_ids)["last_hidden_state"]
+ output_from_past = model(next_tokens, past_key_values=past)["last_hidden_state"]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_bloom_model_attention_mask_past(self, config, input_ids, input_mask, *args):
+ model = BloomModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # create attention mask
+ attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+ half_seq_length = self.seq_length // 2
+ attn_mask[:, half_seq_length:] = 0
+
+ # first forward pass
+ output, past = model(input_ids, attention_mask=attn_mask).to_tuple()
+
+ # create hypothetical next token and extend to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # change a random masked slice from input_ids
+ random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
+ random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
+ input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
+
+ # append to next input_ids and attn_mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ attn_mask = torch.cat(
+ [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
+ dim=1,
+ )
+
+ # get two different outputs
+ output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_bloom_model_past_large_inputs(self, config, input_ids, input_mask, *args):
+ model = BloomModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
+
+ output, past = outputs.to_tuple()
+
+ # create hypothetical next token and extend to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
+
+ # append to next input_ids and attention_mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
+
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past)[
+ "last_hidden_state"
+ ]
+ self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_lm_head_model(self, config, input_ids, input_mask, *args):
+ model = BloomForCausalLM(config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(input_ids, labels=input_ids)
+ self.parent.assertEqual(result.loss.shape, ())
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+ def create_and_check_sequence_classification_model(self, config, input_ids, input_mask, *args):
+ config.num_labels = self.num_labels
+ model = BloomForSequenceClassification(config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(input_ids, attention_mask=input_mask)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def create_and_check_token_classification_model(self, config, input_ids, input_mask, *args):
+ model = BloomForTokenClassification(config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(input_ids, attention_mask=input_mask)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
+
+ def create_and_check_forward_and_backwards(
+ self, config, input_ids, input_mask, *args, gradient_checkpointing=False
+ ):
+ model = BloomForCausalLM(config)
+ model.to(torch_device)
+ if gradient_checkpointing:
+ model.gradient_checkpointing_enable()
+
+ result = model(input_ids, labels=input_ids)
+ self.parent.assertEqual(result.loss.shape, ())
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+ result.loss.backward()
+
+ def create_and_check_bloom_weight_initialization(self, config, *args):
+ model = BloomModel(config)
+ model_std = model.config.initializer_range / math.sqrt(2 * model.config.n_layer)
+ for key in model.state_dict().keys():
+ if "c_proj" in key and "weight" in key:
+ self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
+ self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)
+
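As a quick sanity check of the target value in create_and_check_bloom_weight_initialization (a sketch using this tester's defaults, initializer_range=0.02 and num_hidden_layers=5):

import math

initializer_range = 0.02  # BloomModelTester default
n_layer = 5               # BloomModelTester default (num_hidden_layers)
model_std = initializer_range / math.sqrt(2 * n_layer)
print(round(model_std, 6))  # 0.006325 -- the std the "c_proj" weights are compared against (within 0.001)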
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+
+ config, input_ids, input_mask, sequence_labels = config_and_inputs
+
+ inputs_dict = {"input_ids": input_ids}
+
+ return config, inputs_dict
+
+
+@require_torch
+class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+
+ all_model_classes = (
+ (
+ BloomModel,
+ BloomForCausalLM,
+ BloomForSequenceClassification,
+ BloomForTokenClassification,
+ )
+ if is_torch_available()
+ else ()
+ )
+
+ all_generative_model_classes = (BloomForCausalLM,) if is_torch_available() else ()
+ fx_compatible = True
+ test_missing_keys = False
+ test_pruning = False
+ test_torchscript = True  # torch.autograd functions seem not to be supported
+
+ def setUp(self):
+ self.model_tester = BloomModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=BloomConfig, n_embd=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_bloom_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_model(*config_and_inputs)
+
+ def test_bloom_model_past(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_model_past(*config_and_inputs)
+
+ def test_bloom_model_att_mask_past(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_model_attention_mask_past(*config_and_inputs)
+
+ def test_bloom_model_past_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_model_past_large_inputs(*config_and_inputs)
+
+ def test_bloom_lm_head_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
+
+ def test_bloom_sequence_classification_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_sequence_classification_model(*config_and_inputs)
+
+ def test_bloom_token_classification_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_token_classification_model(*config_and_inputs)
+
+ def test_bloom_gradient_checkpointing(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
+
+ def test_bloom_weight_initialization(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_weight_initialization(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = BloomModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+ @slow
+ @require_torch_gpu
+ def test_simple_generation(self):
+ # This test is a bit flaky. On some GPU architectures, PyTorch sets allow_fp16_reduced_precision_reduction = True
+ # by default, and some operations do not give the same results under this configuration, especially torch.baddbmm
+ # and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200
+ # As we leave the default value (True) for allow_fp16_reduced_precision_reduction, the test failed when running
+ # in half precision with smaller models (350m). Please see: https://pytorch.org/docs/stable/notes/cuda.html#reduced-precision-reduction-in-fp16-gemms
+ # This discrepancy is observed only with small models; the outputs seem to be stable for larger models.
+ # Our conclusion is that these operations (`baddbmm` and `bmm`) are flaky for small inputs but stable for larger inputs, and therefore for larger models.
+
+ # Here is a summary of an ablation study of our observations
+ # EXPECTED_OUTPUT = "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am a very good listener. I am a very good person, and I am a very good person. I am a"
+ # 350m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS
+ # 350m + allow_fp16_reduced_precision_reduction = False + torch.baddbmm ==> PASS
+ # 350m + allow_fp16_reduced_precision_reduction = True + torch.baddbmm ==> PASS
+ # 350m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> FAIL
+
+ # EXPECTED_OUTPUT = "I enjoy walking with my cute dog, but I also enjoy hiking, biking, and swimming. I love to cook and bake. I love to cook and bake. I love to cook and bake. I love to cook and bake. I love"
+ # >=760m + allow_fp16_reduced_precision_reduction = True + torch.baddbmm ==> PASS (for use_cache=True and use_cache=False)
+ # >=760m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS
+ # >=760m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS
+
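A minimal sketch (not part of the test) of the PyTorch switch the ablation above refers to; turning it off makes fp16 GEMM reductions accumulate in full precision, which is what stabilizes the small-model outputs.

import torch

# Global flag referenced in the comment above; available in recent PyTorch releases.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False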
+ path_350m = "bigscience/bloom-350m"
+ model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True, revision="gs555750").cuda()
+ model = model.eval()
+ tokenizer = BloomTokenizerFast.from_pretrained(path_350m)
+
+ input_sentence = "I enjoy walking with my cute dog"
+ # This output has been obtained using fp32 model on the huggingface DGX workstation - NVIDIA A100 GPU
+ EXPECTED_OUTPUT = (
+ "I enjoy walking with my cute dog, and I love to watch the kids play with the kids. I am a very "
+ "active person, and I enjoy working out, and I am a very active person. I am a very active person, and I"
+ )
+
+ input_ids = tokenizer.encode(input_sentence, return_tensors="pt")
+ greedy_output = model.generate(input_ids.cuda(), max_length=50)
+
+ self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
+
+ @slow
+ @require_torch_gpu
+ def test_batch_generation(self):
+ path_350m = "bigscience/bloom-350m"
+ model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True, revision="gs555750").cuda()
+ model = model.eval()
+ tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")
+
+ input_sentence = ["I enjoy walking with my cute dog", "I enjoy walking with my cute dog"]
+
+ input_ids = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
+ greedy_output = model.generate(
+ input_ids["input_ids"].cuda(), attention_mask=input_ids["attention_mask"], max_length=50, do_sample=False
+ )
+
+ self.assertEqual(
+ tokenizer.decode(greedy_output[0], skip_special_tokens=True),
+ tokenizer.decode(greedy_output[1], skip_special_tokens=True),
+ )
+
+ @slow
+ @require_torch_gpu
+ def test_batch_generation_padd(self):
+
+ path_350m = "bigscience/bloom-350m"
+ model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True, revision="gs555750").cuda()
+ model = model.eval()
+ tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")
+
+ input_sentence = ["I enjoy walking with my cute dog", "Hello my name is"]
+ input_sentence_without_pad = "Hello my name is"
+
+ input_ids = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
+ input_ids_without_pad = tokenizer.encode(input_sentence_without_pad, return_tensors="pt")
+
+ greedy_output = model.generate(
+ input_ids["input_ids"].cuda(), attention_mask=input_ids["attention_mask"], max_length=50, do_sample=False
+ )
+ greedy_output_without_pad = model.generate(input_ids_without_pad.cuda(), max_length=50, do_sample=False)
+
+ # test token values
+ self.assertEqual(greedy_output[-1, 3:].tolist(), greedy_output_without_pad[0, :-3].tolist())
+
+ # test reconstructions
+ self.assertEqual(
+ tokenizer.decode(greedy_output[-1, 3:], skip_special_tokens=True),
+ tokenizer.decode(greedy_output_without_pad[0, :-3], skip_special_tokens=True),
+ )
+
+
+@require_torch
+class BloomEmbeddingTest(unittest.TestCase):
+ """
+ The goal here is to compare the embeddings generated by a model trained with Megatron-LM
+ against those from the transformers library, using a small GPT2-like model, to ensure that
+ the conversion from Megatron-LM to transformers was done successfully.
+ The script compares the logits of the embedding layer and the transformer layers.
+
+ WARNING: these logits are not expected to have exactly the same statistics when running
+ the code on CPU versus GPU. For more info, please visit:
+ - https://github.com/pytorch/pytorch/issues/76052#issuecomment-1103193548
+ - https://discuss.pytorch.org/t/reproducibility-issue-between-intel-and-amd-cpus/144779/9
+
+
+ You need to install tokenizers following this readme:
+ - https://huggingface.co/bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles
+
+ Tokenizer used during training:
+ - https://huggingface.co/bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles
+
+ # TODO change the script (or just add skip) when building the env with tokenizers 0.12.0
+ """
+
+ def setUp(self):
+ super().setUp()
+ self.path_bigscience_model = "bigscience/bigscience-small-testing"
+
+ @require_torch
+ def test_embeddings(self):
+ model = BloomForCausalLM.from_pretrained(self.path_bigscience_model, torch_dtype="auto") # load in fp32
+ model.eval()
+
+ EMBEDDINGS_DS_BEFORE_LN_BF_16_MEAN = {
+ 3478: 0.0002307891845703125,
+ 368: -0.000568389892578125,
+ 109586: -0.0003910064697265625,
+ 35433: -0.000194549560546875,
+ 2: 0.0004138946533203125,
+ 77: 0.000659942626953125,
+ 132619: -0.00031280517578125,
+ 2175: 0.000457763671875,
+ 23714: 0.000263214111328125,
+ 73173: -0.000286102294921875,
+ 144252: 0.00052642822265625,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_BF_16_MIN = {
+ 3478: -0.00921630859375,
+ 368: -0.010009765625,
+ 109586: -0.01031494140625,
+ 35433: -0.01177978515625,
+ 2: -0.0074462890625,
+ 77: -0.00848388671875,
+ 132619: -0.009521484375,
+ 2175: -0.0074462890625,
+ 23714: -0.0145263671875,
+ 73173: -0.007415771484375,
+ 144252: -0.01007080078125,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_BF_16_MAX = {
+ 3478: 0.0128173828125,
+ 368: 0.01214599609375,
+ 109586: 0.0111083984375,
+ 35433: 0.01019287109375,
+ 2: 0.0157470703125,
+ 77: 0.0174560546875,
+ 132619: 0.0078125,
+ 2175: 0.0113525390625,
+ 23714: 0.0146484375,
+ 73173: 0.01116943359375,
+ 144252: 0.01141357421875,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_BF_16_SUM = {"value": 0.08203125}
+
+ EMBEDDINGS_DS_BEFORE_LN_F_16_MEAN = {
+ 132619: -0.00031256675720214844,
+ 3478: 0.00023090839385986328,
+ 368: -0.0005702972412109375,
+ 109586: -0.00039124488830566406,
+ 35433: -0.000194549560546875,
+ 2: 0.0004146099090576172,
+ 2175: 0.0004572868347167969,
+ 23714: 0.00026416778564453125,
+ 73173: -0.0002865791320800781,
+ 144252: 0.0005254745483398438,
+ 77: 0.0006618499755859375,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_16_MIN = {
+ 3478: -0.00921630859375,
+ 368: -0.010009765625,
+ 109586: -0.01031494140625,
+ 35433: -0.01177978515625,
+ 2: -0.0074462890625,
+ 77: -0.00848388671875,
+ 132619: -0.009521484375,
+ 2175: -0.0074462890625,
+ 23714: -0.0145263671875,
+ 73173: -0.007415771484375,
+ 144252: -0.01007080078125,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_16_MAX = {
+ 3478: 0.0128173828125,
+ 368: 0.01214599609375,
+ 109586: 0.0111083984375,
+ 35433: 0.01019287109375,
+ 2: 0.0157470703125,
+ 77: 0.0174560546875,
+ 132619: 0.0078125,
+ 2175: 0.0113525390625,
+ 23714: 0.0146484375,
+ 73173: 0.01116943359375,
+ 144252: 0.01141357421875,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_16_SUM = {"value": 0.0821533203125}
+
+ EMBEDDINGS_DS_BEFORE_LN_F_32_MEAN = {
+ 132619: -0.00031267106533050537,
+ 3478: 0.00023087859153747559,
+ 368: -0.0005701072514057159,
+ 109586: -0.0003911703824996948,
+ 35433: -0.0001944899559020996,
+ 2: 0.0004146844148635864,
+ 2175: 0.00045740045607089996,
+ 23714: 0.0002641640603542328,
+ 73173: -0.0002864748239517212,
+ 144252: 0.0005256589502096176,
+ 77: 0.0006617321632802486,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_32_MIN = {
+ 3478: -0.00921630859375,
+ 368: -0.010009765625,
+ 109586: -0.01031494140625,
+ 35433: -0.01177978515625,
+ 2: -0.0074462890625,
+ 77: -0.00848388671875,
+ 132619: -0.009521484375,
+ 2175: -0.0074462890625,
+ 23714: -0.0145263671875,
+ 73173: -0.007415771484375,
+ 144252: -0.01007080078125,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_32_MAX = {
+ 3478: 0.0128173828125,
+ 368: 0.01214599609375,
+ 109586: 0.0111083984375,
+ 35433: 0.01019287109375,
+ 2: 0.0157470703125,
+ 77: 0.0174560546875,
+ 132619: 0.0078125,
+ 2175: 0.0113525390625,
+ 23714: 0.0146484375,
+ 73173: 0.01116943359375,
+ 144252: 0.01141357421875,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_32_SUM = {"value": 0.08217757940292358}
+
+ TEST_EMBEDDINGS = {
+ "torch.bfloat16": {
+ "mean": EMBEDDINGS_DS_BEFORE_LN_BF_16_MEAN,
+ "max": EMBEDDINGS_DS_BEFORE_LN_BF_16_MAX,
+ "min": EMBEDDINGS_DS_BEFORE_LN_BF_16_MIN,
+ "sum": EMBEDDINGS_DS_BEFORE_LN_BF_16_SUM,
+ },
+ "torch.float32": {
+ "mean": EMBEDDINGS_DS_BEFORE_LN_F_32_MEAN,
+ "max": EMBEDDINGS_DS_BEFORE_LN_F_32_MAX,
+ "min": EMBEDDINGS_DS_BEFORE_LN_F_32_MIN,
+ "sum": EMBEDDINGS_DS_BEFORE_LN_F_32_SUM,
+ },
+ "torch.float": {
+ "mean": EMBEDDINGS_DS_BEFORE_LN_F_32_MEAN,
+ "max": EMBEDDINGS_DS_BEFORE_LN_F_32_MAX,
+ "min": EMBEDDINGS_DS_BEFORE_LN_F_32_MIN,
+ "sum": EMBEDDINGS_DS_BEFORE_LN_F_32_SUM,
+ },
+ "torch.float16": {
+ "mean": EMBEDDINGS_DS_BEFORE_LN_F_16_MEAN,
+ "max": EMBEDDINGS_DS_BEFORE_LN_F_16_MAX,
+ "min": EMBEDDINGS_DS_BEFORE_LN_F_16_MIN,
+ "sum": EMBEDDINGS_DS_BEFORE_LN_F_16_SUM,
+ },
+ }
+
+ # fmt: off
+ EXAMPLE_IDS = [3478, 368, 109586, 35433, 2, 77, 132619, 3478, 368, 109586, 35433, 2, 2175, 23714, 73173, 144252, 2, 77, 132619, 3478]
+ # fmt: on
+
+ EMBEDDINGS_DS_AFTER_LN_MEAN = {
+ 3478: -6.580352783203125e-05,
+ 368: 0.0001316070556640625,
+ 109586: -0.00030517578125,
+ 35433: 4.00543212890625e-05,
+ 2: -7.2479248046875e-05,
+ 77: -8.96453857421875e-05,
+ 132619: 0.0001583099365234375,
+ 2175: 2.1219253540039062e-05,
+ 23714: -0.000247955322265625,
+ 73173: -0.00021839141845703125,
+ 144252: -0.0001430511474609375,
+ }
+ EMBEDDINGS_DS_AFTER_LN_MIN = {
+ 3478: -1.6953125,
+ 368: -1.6875,
+ 109586: -1.6875,
+ 35433: -2.125,
+ 2: -1.390625,
+ 77: -1.5390625,
+ 132619: -1.875,
+ 2175: -1.4609375,
+ 23714: -2.296875,
+ 73173: -1.3515625,
+ 144252: -1.78125,
+ }
+ EMBEDDINGS_DS_AFTER_LN_MAX = {
+ 3478: 2.265625,
+ 368: 2.28125,
+ 109586: 1.953125,
+ 35433: 1.90625,
+ 2: 2.703125,
+ 77: 2.828125,
+ 132619: 1.65625,
+ 2175: 2.015625,
+ 23714: 2.234375,
+ 73173: 2.171875,
+ 144252: 1.828125,
+ }
+
+ EMBEDDINGS_DS_AFTER_LN = {
+ "mean": EMBEDDINGS_DS_AFTER_LN_MEAN,
+ "min": EMBEDDINGS_DS_AFTER_LN_MIN,
+ "max": EMBEDDINGS_DS_AFTER_LN_MAX,
+ }
+
+ tensor_ids = torch.LongTensor([EXAMPLE_IDS])
+ with torch.no_grad():
+ embeddings = model.transformer.word_embeddings(tensor_ids)
+ embeddings_ln = model.transformer.word_embeddings_layernorm(embeddings)
+ # first check the embeddings before LN
+ output_dict = {"min": {}, "max": {}, "mean": {}, "sum": {"value": embeddings.sum().item()}}
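+ # per-token statistics are taken over the hidden dimension (dim=-1) of the single batch element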
+ for i, idx in enumerate(EXAMPLE_IDS):
+ output_dict["min"][idx] = embeddings.min(dim=-1).values[0][i].item()
+ output_dict["max"][idx] = embeddings.max(dim=-1).values[0][i].item()
+ output_dict["mean"][idx] = embeddings.mean(dim=-1)[0][i].item()
+
+ for key in TEST_EMBEDDINGS[str(model.dtype)].keys():
+ self.assertDictEqual(TEST_EMBEDDINGS[str(model.dtype)][key], output_dict[key])
+
+ output_dict_norm = {"min": {}, "max": {}, "mean": {}}
+ for i, idx in enumerate(EXAMPLE_IDS):
+ output_dict_norm["min"][idx] = embeddings_ln.min(dim=-1).values[0][i].item()
+ output_dict_norm["max"][idx] = embeddings_ln.max(dim=-1).values[0][i].item()
+ output_dict_norm["mean"][idx] = embeddings_ln.mean(dim=-1)[0][i].item()
+
+ # This test does not pass when places = 2
+ for i, key in enumerate(output_dict_norm.keys()):
+ for j, idx in enumerate(output_dict[key].keys()):
+ self.assertAlmostEqual(EMBEDDINGS_DS_AFTER_LN[key][idx], output_dict_norm[key][idx], places=1)
+
+ @require_torch
+ def test_hidden_states_transformers(self):
+ cuda_available = torch.cuda.is_available()
+ model = BloomModel.from_pretrained(self.path_bigscience_model, use_cache=False, torch_dtype="auto").to(
+ torch_device
+ )
+ model.eval()
+
+ # fmt: off
+ EXAMPLE_IDS = [3478, 368, 109586, 35433, 2, 77, 132619, 3478, 368, 109586, 35433, 2, 2175, 23714, 73173, 144252, 2, 77, 132619, 3478]
+ # fmt: on
+
+ MEAN_VALUE_LAST_LM = -4.3392181396484375e-05
+ MIN_MAX_DICT = {"min": -2.0625, "max": 2.75}
+ tensor_ids = torch.LongTensor([EXAMPLE_IDS])
+
+ with torch.no_grad():
+ logits = model(tensor_ids.to(torch_device))
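+ # min/max over the hidden dimension for the first token of the single input sequence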
+ output_dict = {
+ "min": logits.last_hidden_state.min(dim=-1).values[0][0].item(),
+ "max": logits.last_hidden_state.max(dim=-1).values[0][0].item(),
+ }
+
+ if cuda_available:
+ self.assertAlmostEqual(MEAN_VALUE_LAST_LM, logits.last_hidden_state.mean().item(), places=4)
+ else:
+ self.assertAlmostEqual(MEAN_VALUE_LAST_LM, logits.last_hidden_state.mean().item(), places=3)
+
+ self.assertDictEqual(MIN_MAX_DICT, output_dict)
+
+ @require_torch
+ def test_logits(self):
+ cuda_available = torch.cuda.is_available()
+ model = BloomForCausalLM.from_pretrained(self.path_bigscience_model, use_cache=False, torch_dtype="auto").to(
+ torch_device
+ ) # load in bf16
+ model.eval()
+
+ # fmt: off
+ EXAMPLE_IDS = [3478, 368, 109586, 35433, 2, 77, 132619, 3478, 368, 109586, 35433, 2, 2175, 23714, 73173, 144252, 2, 77, 132619, 3478]
+ # fmt: on
+
+ MEAN_LOGITS_GPU_1 = -1.823902130126953e-05
+ MEAN_LOGITS_GPU_2 = 1.9431114196777344e-05
+
+ tensor_ids = torch.LongTensor([EXAMPLE_IDS]).to(torch_device)
+ with torch.no_grad():
+ output = model(tensor_ids).logits
+
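+ # split the logits along the vocab dimension: 2 x 125440 = 250880 (BLOOM's vocabulary size), presumably mirroring a 2-way tensor-parallel sharding of the output embedding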
+ output_gpu_1, output_gpu_2 = output.split(125440, dim=-1)
+ if cuda_available:
+ self.assertEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1)
+ self.assertEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2)
+ else:
+ self.assertAlmostEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1, places=6) # 1e-06 precision!!
+ self.assertAlmostEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2, places=6)
diff --git a/tests/models/bloom/test_tokenization_bloom.py b/tests/models/bloom/test_tokenization_bloom.py
new file mode 100644
index 000000000000..117240dbdae1
--- /dev/null
+++ b/tests/models/bloom/test_tokenization_bloom.py
@@ -0,0 +1,136 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from datasets import load_dataset
+
+from transformers import BloomTokenizerFast
+from transformers.testing_utils import require_tokenizers
+
+from ...test_tokenization_common import TokenizerTesterMixin
+
+
+@require_tokenizers
+class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+
+ slow_tokenizer_class = None
+ rust_tokenizer_class = BloomTokenizerFast
+ tokenizer_class = BloomTokenizerFast
+ test_rust_tokenizer = True
+ test_slow_tokenizer = False
+ from_pretrained_vocab_key = "tokenizer_file"
+ special_tokens_map = {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}
+
+ def setUp(self):
+ super().setUp()
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/tokenizer")
+ tokenizer.save_pretrained(self.tmpdirname)
+
+ def get_rust_tokenizer(self, **kwargs):
+ kwargs.update(self.special_tokens_map)
+ return BloomTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+
+ def test_encodings_from_sample_data(self):
+ """
+ Assert that the created tokens are the same as the hard-coded ones
+ """
+ tokenizer = self.get_rust_tokenizer()
+
+ INPUT_SENTENCES = ["The quick brown fox", "jumps over the lazy dog"]
+ TARGET_TOKENS = [[2175, 23714, 73173, 144252, 2], [77, 132619, 3478, 368, 109586, 35433, 2]]
+
+ computed_tokens = tokenizer.batch_encode_plus(INPUT_SENTENCES)["input_ids"]
+ self.assertListEqual(TARGET_TOKENS, computed_tokens)
+
+ decoded_tokens = tokenizer.batch_decode(computed_tokens)
+ self.assertListEqual(decoded_tokens, INPUT_SENTENCES)
+
+ def test_padding(self, max_length=6):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ # tokenizer_r.pad_token = None # Hotfixing padding = None
+ # Simple input
+ s = "This is a simple input"
+ s2 = ["This is a simple input 1", "This is a simple input 2"]
+ p = ("This is a simple input", "This is a pair")
+ p2 = [
+ ("This is a simple input 1", "This is a simple input 2"),
+ ("This is a simple pair 1", "This is a simple pair 2"),
+ ]
+
+ # Simple input tests
+ try:
+ tokenizer_r.encode(s, max_length=max_length)
+ tokenizer_r.encode_plus(s, max_length=max_length)
+
+ tokenizer_r.batch_encode_plus(s2, max_length=max_length)
+ tokenizer_r.encode(p, max_length=max_length)
+ tokenizer_r.batch_encode_plus(p2, max_length=max_length)
+ except ValueError:
+ self.fail("Bloom Tokenizer should be able to deal with padding")
+
+ tokenizer_r.pad_token = None # Hotfixing padding = None
+ self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")
+
+ # Simple input
+ self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")
+
+ # Simple input
+ self.assertRaises(
+ ValueError,
+ tokenizer_r.batch_encode_plus,
+ s2,
+ max_length=max_length,
+ padding="max_length",
+ )
+
+ # Pair input
+ self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")
+
+ # Pair input
+ self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")
+
+ # Pair input
+ self.assertRaises(
+ ValueError,
+ tokenizer_r.batch_encode_plus,
+ p2,
+ max_length=max_length,
+ padding="max_length",
+ )
+
+ def test_encodings_from_xnli_dataset(self):
+ """
+ Tests the tokenizer downloaded from here:
+ - https://huggingface.co/bigscience/tokenizer/
+ """
+ tokenizer = self.get_rust_tokenizer()
+ ds = load_dataset("xnli", "all_languages", split="test", streaming=True)
+
+ sample_data = next(iter(ds))["premise"] # pick up a single example
+ input_text = list(sample_data.values())
+
+ output_tokens = list(map(tokenizer.encode, input_text))
+ predicted_text = list(map(lambda x: tokenizer.decode(x, clean_up_tokenization_spaces=False), output_tokens))
+ self.assertListEqual(predicted_text, input_text)
+
+ def test_pretrained_model_lists(self):
+ # The test has to be overridden because BLOOM uses ALiBi positional embeddings, which do not impose
+ # any sequence length constraints. The parent class test would fail since it relies on the
+ # maximum sequence length of the positional embeddings.
+ self.assertGreaterEqual(len(self.tokenizer_class.pretrained_vocab_files_map), 1)
+ self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_vocab_files_map.values())[0]), 1)
diff --git a/tests/models/byt5/test_tokenization_byt5.py b/tests/models/byt5/test_tokenization_byt5.py
index 70cfa40ef919..85057c5278bb 100644
--- a/tests/models/byt5/test_tokenization_byt5.py
+++ b/tests/models/byt5/test_tokenization_byt5.py
@@ -152,10 +152,9 @@ def test_max_length_integration(self):
"Summary of the text.",
"Another summary.",
]
- with tokenizer.as_target_tokenizer():
- targets = tokenizer(
- tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
- )
+ targets = tokenizer(
+ text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
+ )
self.assertEqual(32, targets["input_ids"].shape[1])
def test_eos_in_input(self):
@@ -167,12 +166,10 @@ def test_eos_in_input(self):
expected_tgt_tokens = [86, 120, 112, 112, 100, 117, 124, 35, 114, 105, 35, 119, 107, 104, 35, 119, 104, 123, 119, 49, 35, 1]
# fmt: on
- batch = tokenizer(src_text)
- with tokenizer.as_target_tokenizer():
- targets = tokenizer(tgt_text)
+ batch = tokenizer(src_text, text_target=tgt_text)
self.assertEqual(expected_src_tokens, batch["input_ids"][0])
- self.assertEqual(expected_tgt_tokens, targets["input_ids"][0])
+ self.assertEqual(expected_tgt_tokens, batch["labels"][0])
# cannot use default save_and_load_tokenzier test method because tokenzier has no vocab
def test_save_and_load_tokenizer(self):
diff --git a/tests/models/canine/test_modeling_canine.py b/tests/models/canine/test_modeling_canine.py
index 483dd095a18b..a4d13f0efab6 100644
--- a/tests/models/canine/test_modeling_canine.py
+++ b/tests/models/canine/test_modeling_canine.py
@@ -378,7 +378,12 @@ def recursive_check(tuple_object, dict_object):
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
- msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+ f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
+ f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
+ ),
)
recursive_check(tuple_output, dict_output)
diff --git a/tests/models/canine/test_tokenization_canine.py b/tests/models/canine/test_tokenization_canine.py
index 0e016d523b5c..6ae27082cceb 100644
--- a/tests/models/canine/test_tokenization_canine.py
+++ b/tests/models/canine/test_tokenization_canine.py
@@ -80,8 +80,9 @@ def test_max_length_integration(self):
"What's the weater?",
"It's about 25 degrees.",
]
- with tokenizer.as_target_tokenizer():
- targets = tokenizer(tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors="pt")
+ targets = tokenizer(
+ text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors="pt"
+ )
self.assertEqual(32, targets["input_ids"].shape[1])
# cannot use default save_and_load_tokenzier test method because tokenzier has no vocab
diff --git a/tests/models/clip/test_feature_extraction_clip.py b/tests/models/clip/test_feature_extraction_clip.py
index a3f0817ea0b2..8f36a65ae2d5 100644
--- a/tests/models/clip/test_feature_extraction_clip.py
+++ b/tests/models/clip/test_feature_extraction_clip.py
@@ -49,6 +49,7 @@ def __init__(
do_normalize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
+ do_convert_rgb=True,
):
self.parent = parent
self.batch_size = batch_size
@@ -63,6 +64,7 @@ def __init__(
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
+ self.do_convert_rgb = do_convert_rgb
def prepare_feat_extract_dict(self):
return {
@@ -73,6 +75,7 @@ def prepare_feat_extract_dict(self):
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
+ "do_convert_rgb": self.do_convert_rgb,
}
def prepare_inputs(self, equal_resolution=False, numpify=False, torchify=False):
@@ -128,6 +131,7 @@ def test_feat_extract_properties(self):
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
+ self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
def test_batch_feature(self):
pass
@@ -227,3 +231,64 @@ def test_call_pytorch(self):
self.feature_extract_tester.crop_size,
),
)
+
+
+@require_torch
+@require_vision
+class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = CLIPFeatureExtractor if is_vision_available() else None
+
+ def setUp(self):
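+ # images are generated with 4 channels (e.g. RGBA); do_convert_rgb should collapse them to 3-channel RGB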
+ self.feature_extract_tester = CLIPFeatureExtractionTester(self, num_channels=4)
+ self.expected_encoded_image_num_channels = 3
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "size"))
+ self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "image_std"))
+ self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil_four_channels(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.expected_encoded_image_num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.expected_encoded_image_num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py
index 7ae1146159e6..ab05f9adf1e8 100644
--- a/tests/models/clip/test_modeling_clip.py
+++ b/tests/models/clip/test_modeling_clip.py
@@ -100,6 +100,10 @@ def __init__(
self.initializer_range = initializer_range
self.scope = scope
+ # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
+ num_patches = (image_size // patch_size) ** 2
+ self.seq_length = num_patches + 1
+
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = self.get_config()
@@ -148,7 +152,7 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
"""
all_model_classes = (CLIPVisionModel,) if is_torch_available() else ()
-
+ fx_compatible = True
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
@@ -160,8 +164,8 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()
+ @unittest.skip(reason="CLIP does not use inputs_embeds")
def test_inputs_embeds(self):
- # CLIP does not use inputs_embeds
pass
def test_model_common_attributes(self):
@@ -189,114 +193,17 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
- def test_attention_outputs(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.return_dict = True
-
- # in CLIP, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
- image_size = (self.model_tester.image_size, self.model_tester.image_size)
- patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- seq_len = num_patches + 1
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- config.return_dict = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- out_len = len(outputs)
-
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- added_hidden_states = 1
- self.assertEqual(out_len + added_hidden_states, len(outputs))
-
- self_attentions = outputs.attentions
-
- self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
-
- self.assertListEqual(
- list(self_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_len, seq_len],
- )
-
- def test_hidden_states_output(self):
- def check_hidden_states_output(inputs_dict, config, model_class):
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
-
- expected_num_layers = getattr(
- self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
- )
- self.assertEqual(len(hidden_states), expected_num_layers)
-
- # CLIP has a different seq_length
- image_size = (self.model_tester.image_size, self.model_tester.image_size)
- patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- seq_length = num_patches + 1
-
- self.assertListEqual(
- list(hidden_states[0].shape[-2:]),
- [seq_length, self.model_tester.hidden_size],
- )
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- inputs_dict["output_hidden_states"] = True
- check_hidden_states_output(inputs_dict, config, model_class)
-
- # check that output_hidden_states also work using config
- del inputs_dict["output_hidden_states"]
- config.output_hidden_states = True
-
- check_hidden_states_output(inputs_dict, config, model_class)
-
def test_training(self):
pass
def test_training_gradient_checkpointing(self):
pass
- # skip this test as CLIPVisionModel has no base class and is
- # not available in MODEL_MAPPING
+ @unittest.skip(reason="CLIPVisionModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_from_base(self):
pass
- # skip this test as CLIPVisionModel has no base class and is
- # not available in MODEL_MAPPING
+ @unittest.skip(reason="CLIPVisionModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_to_base(self):
pass
@@ -396,6 +303,7 @@ def prepare_config_and_inputs_for_common(self):
class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (CLIPTextModel,) if is_torch_available() else ()
+ fx_compatible = True
test_pruning = False
test_head_masking = False
@@ -416,17 +324,15 @@ def test_training(self):
def test_training_gradient_checkpointing(self):
pass
+ @unittest.skip(reason="CLIP does not use inputs_embeds")
def test_inputs_embeds(self):
- # CLIP does not use inputs_embeds
pass
- # skip this test as CLIPTextModel has no base class and is
- # not available in MODEL_MAPPING
+ @unittest.skip(reason="CLIPTextModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_from_base(self):
pass
- # skip this test as CLIPTextModel has no base class and is
- # not available in MODEL_MAPPING
+ @unittest.skip(reason="CLIPTextModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_to_base(self):
pass
@@ -483,6 +389,7 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (CLIPModel,) if is_torch_available() else ()
+ fx_compatible = True
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
@@ -495,19 +402,19 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
- # hidden_states are tested in individual model tests
+ @unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass
- # input_embeds are tested in individual model tests
+ @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
def test_inputs_embeds(self):
pass
- # tested in individual model tests
+ @unittest.skip(reason="Retain_grad is tested in individual model tests")
def test_retain_grad_hidden_states_attentions(self):
pass
- # CLIPModel does not have input/output embeddings
+ @unittest.skip(reason="CLIPModel does not have input/output embeddings")
def test_model_common_attributes(self):
pass
diff --git a/tests/models/clip/test_modeling_tf_clip.py b/tests/models/clip/test_modeling_tf_clip.py
index 797d5b73b349..05b4c7920ebd 100644
--- a/tests/models/clip/test_modeling_tf_clip.py
+++ b/tests/models/clip/test_modeling_tf_clip.py
@@ -606,11 +606,21 @@ def test_model_from_pretrained(self):
model = TFCLIPModel.from_pretrained(model_name)
self.assertIsNotNone(model)
+ @unittest.skip(reason="Currently `saved_model` doesn't work with nested outputs.")
+ @slow
+ def test_saved_model_creation(self):
+ pass
+
@unittest.skip(reason="Currently `saved_model` doesn't work with nested outputs.")
@slow
def test_saved_model_creation_extended(self):
pass
+ @unittest.skip(reason="`saved_model` doesn't work with nested outputs so no preparation happens.")
+ @slow
+ def test_prepare_serving_output(self):
+ pass
+
# We will verify our results on an image of cute cats
def prepare_img():
diff --git a/tests/models/codegen/__init__.py b/tests/models/codegen/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/codegen/test_modeling_codegen.py b/tests/models/codegen/test_modeling_codegen.py
new file mode 100644
index 000000000000..b59adc78181d
--- /dev/null
+++ b/tests/models/codegen/test_modeling_codegen.py
@@ -0,0 +1,554 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import datetime
+import unittest
+
+from transformers import CodeGenConfig, is_torch_available
+from transformers.testing_utils import require_torch, slow, torch_device
+
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST, AutoTokenizer, CodeGenForCausalLM, CodeGenModel
+
+
+class CodeGenModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=14,
+ seq_length=7,
+ is_training=True,
+ use_token_type_ids=True,
+ use_input_mask=True,
+ use_labels=True,
+ use_mc_token_ids=True,
+ vocab_size=256,
+ hidden_size=32,
+ rotary_dim=4,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ num_labels=3,
+ num_choices=4,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_token_type_ids = use_token_type_ids
+ self.use_input_mask = use_input_mask
+ self.use_labels = use_labels
+ self.use_mc_token_ids = use_mc_token_ids
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.rotary_dim = rotary_dim
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.scope = None
+ self.bos_token_id = vocab_size - 1
+ self.eos_token_id = vocab_size - 1
+ self.pad_token_id = vocab_size - 1
+
+ def get_large_model_config(self):
+ return CodeGenConfig.from_pretrained("Salesforce/codegen-2B-mono")
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ token_type_ids = None
+ if self.use_token_type_ids:
+ token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+ mc_token_ids = None
+ if self.use_mc_token_ids:
+ mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
+
+ sequence_labels = None
+ token_labels = None
+ choice_labels = None
+ if self.use_labels:
+ sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+ token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+ choice_labels = ids_tensor([self.batch_size], self.num_choices)
+
+ config = self.get_config()
+
+ head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
+
+ return (
+ config,
+ input_ids,
+ input_mask,
+ head_mask,
+ token_type_ids,
+ mc_token_ids,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ )
+
+ def get_config(self):
+ return CodeGenConfig(
+ vocab_size=self.vocab_size,
+ n_embd=self.hidden_size,
+ n_layer=self.num_hidden_layers,
+ n_head=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ n_positions=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ initializer_range=self.initializer_range,
+ use_cache=True,
+ bos_token_id=self.bos_token_id,
+ eos_token_id=self.eos_token_id,
+ pad_token_id=self.pad_token_id,
+ rotary_dim=self.rotary_dim,
+ )
+
+ def prepare_config_and_inputs_for_decoder(self):
+ (
+ config,
+ input_ids,
+ input_mask,
+ head_mask,
+ token_type_ids,
+ mc_token_ids,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ) = self.prepare_config_and_inputs()
+
+ encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
+ encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+
+ return (
+ config,
+ input_ids,
+ input_mask,
+ head_mask,
+ token_type_ids,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+
+ def create_and_check_codegen_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
+ model = CodeGenModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
+ result = model(input_ids, token_type_ids=token_type_ids)
+ result = model(input_ids)
+
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(len(result.past_key_values), config.n_layer)
+
+ def create_and_check_codegen_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
+ model = CodeGenModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # first forward pass
+ outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
+ outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
+ outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)
+
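+ # enabling the cache adds exactly one extra output (the past key/values) compared to use_cache=False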
+ self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+ self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+ output, past = outputs.to_tuple()
+
+ # create hypothetical next token and extend to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+ next_token_types = ids_tensor([self.batch_size, 1], self.type_vocab_size)
+
+ # append to next input_ids and token_type_ids
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1)
+
+ output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"]
+ output_from_past = model(next_tokens, token_type_ids=next_token_types, past_key_values=past)[
+ "last_hidden_state"
+ ]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_codegen_model_attention_mask_past(
+ self, config, input_ids, input_mask, head_mask, token_type_ids, *args
+ ):
+ model = CodeGenModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # create attention mask
+ attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+ half_seq_length = self.seq_length // 2
+ attn_mask[:, half_seq_length:] = 0
+
+ # first forward pass
+ output, past = model(input_ids, attention_mask=attn_mask).to_tuple()
+
+ # create hypothetical next token and extend to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # change a random masked slice from input_ids
+ random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
+ random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
+ input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
+
+ # append to next input_ids and attn_mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ attn_mask = torch.cat(
+ [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
+ dim=1,
+ )
+
+ # get two different outputs
+ output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_codegen_model_past_large_inputs(
+ self, config, input_ids, input_mask, head_mask, token_type_ids, *args
+ ):
+ model = CodeGenModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # first forward pass
+ outputs = model(input_ids, token_type_ids=token_type_ids, attention_mask=input_mask, use_cache=True)
+
+ output, past = outputs.to_tuple()
+
+ # create hypothetical next token and extend to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_token_types = ids_tensor([self.batch_size, 3], self.type_vocab_size)
+ next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
+
+ # append to next input_ids and token_type_ids
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1)
+ next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
+
+ output_from_no_past = model(
+ next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask
+ )["last_hidden_state"]
+ output_from_past = model(
+ next_tokens, token_type_ids=next_token_types, attention_mask=next_attention_mask, past_key_values=past
+ )["last_hidden_state"]
+ self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
+ model = CodeGenForCausalLM(config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
+ self.parent.assertEqual(result.loss.shape, ())
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+ def create_and_check_forward_and_backwards(
+ self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
+ ):
+ model = CodeGenForCausalLM(config)
+ if gradient_checkpointing:
+ model.gradient_checkpointing_enable()
+ model.to(torch_device)
+
+ result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
+ self.parent.assertEqual(result.loss.shape, ())
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+ result.loss.backward()
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+
+ (
+ config,
+ input_ids,
+ input_mask,
+ head_mask,
+ token_type_ids,
+ mc_token_ids,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ) = config_and_inputs
+
+ inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "head_mask": head_mask}
+
+ return config, inputs_dict
+
+
+@require_torch
+class CodeGenModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+
+ all_model_classes = (CodeGenModel, CodeGenForCausalLM) if is_torch_available() else ()
+ all_generative_model_classes = (CodeGenForCausalLM,) if is_torch_available() else ()
+ fx_compatible = False
+ test_pruning = False
+ test_missing_keys = False
+ test_model_parallel = False
+ test_head_masking = False
+
+ # plain passthrough to the parent implementation; CodeGen needs no DoubleHeads-style special casing
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+ return inputs_dict
+
+ def setUp(self):
+ self.model_tester = CodeGenModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=CodeGenConfig, n_embd=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_codegen_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_codegen_model(*config_and_inputs)
+
+ def test_codegen_model_past(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_codegen_model_past(*config_and_inputs)
+
+ def test_codegen_model_att_mask_past(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_codegen_model_attention_mask_past(*config_and_inputs)
+
+ def test_codegen_model_past_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_codegen_model_past_large_inputs(*config_and_inputs)
+
+ def test_codegen_lm_head_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
+
+ def test_codegen_gradient_checkpointing(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
+
+ @slow
+ def test_batch_generation(self):
+ tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
+ model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
+ model.to(torch_device)
+
+ tokenizer.padding_side = "left"
+
+ # Define PAD Token = EOS Token = 50256
+ tokenizer.pad_token = tokenizer.eos_token
+ model.config.pad_token_id = model.config.eos_token_id
+
+ # use different length sentences to test batching
+ sentences = ["def hellow_world():", "def greet(name):"]
+
+ inputs = tokenizer(sentences, return_tensors="pt", padding=True)
+ input_ids = inputs["input_ids"].to(torch_device)
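+ # arbitrary token_type_ids (zeros everywhere, 500 on the last position), used only to check further down that passing them changes the generated output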
+ token_type_ids = torch.cat(
+ [
+ input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0),
+ input_ids.new_full((input_ids.shape[0], 1), 500),
+ ],
+ dim=-1,
+ )
+
+ outputs = model.generate(
+ input_ids=input_ids,
+ attention_mask=inputs["attention_mask"].to(torch_device),
+ )
+
+ outputs_tt = model.generate(
+ input_ids=input_ids,
+ attention_mask=inputs["attention_mask"].to(torch_device),
+ token_type_ids=token_type_ids,
+ )
+
+ inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
+ output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+ num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+ inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
+ output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+ batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
+ non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
+ padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+
+ expected_output_sentence = [
+ 'def hellow_world():\n print("Hello World")\n\nhellow_world()',
+ 'def greet(name):\n print(f"Hello {name}")\n\ng',
+ ]
+ self.assertListEqual(expected_output_sentence, batch_out_sentence)
+ self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output
+ self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = CodeGenModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+@require_torch
+class CodeGenModelLanguageGenerationTest(unittest.TestCase):
+ @slow
+ def test_lm_generate_codegen(self):
+ tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
+ for checkpointing in [True, False]:
+ model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
+
+ if checkpointing:
+ model.gradient_checkpointing_enable()
+ else:
+ model.gradient_checkpointing_disable()
+ model.to(torch_device)
+
+ inputs = tokenizer("def hello_world():", return_tensors="pt").to(torch_device)
+ expected_output = 'def hello_world():\n print("Hello World")\n\nhello_world()\n\n'
+
+ output_ids = model.generate(**inputs, do_sample=False)
+ output_str = tokenizer.batch_decode(output_ids)[0]
+
+ self.assertEqual(output_str, expected_output)
+
+ @slow
+ def test_codegen_sample(self):
+ tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
+ model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
+ model.to(torch_device)
+
+ torch.manual_seed(0)
+ if torch_device == "cuda":
+ torch.cuda.manual_seed(0)
+
+ tokenized = tokenizer("def hello_world():", return_tensors="pt", return_token_type_ids=True)
+ input_ids = tokenized.input_ids.to(torch_device)
+ output_ids = model.generate(input_ids, do_sample=True)
+ output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+ token_type_ids = tokenized.token_type_ids.to(torch_device)
+ output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5)
+ output_seq_tt = model.generate(
+ input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5
+ )
+ output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
+ output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)
+
+ if torch_device == "cuda":
+ EXPECTED_OUTPUT_STR = 'def hello_world():\n print("Hello World")\n return True\n\nresult ='
+ else:
+ EXPECTED_OUTPUT_STR = "def hello_world():\r\n print('Hello, World.')\r\n\r\n\r"
+
+ self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
+ self.assertTrue(
+ all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
+ ) # token_type_ids should change output
+
+ @slow
+ def test_codegen_sample_max_time(self):
+ tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
+ model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
+ model.to(torch_device)
+
+ torch.manual_seed(0)
+ tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
+ input_ids = tokenized.input_ids.to(torch_device)
+
+ MAX_TIME = 0.05
+
+ start = datetime.datetime.now()
+ model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
+ duration = datetime.datetime.now() - start
+ self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
+ self.assertLess(duration, datetime.timedelta(seconds=2 * MAX_TIME))
+
+ start = datetime.datetime.now()
+ model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
+ duration = datetime.datetime.now() - start
+ self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
+ self.assertLess(duration, datetime.timedelta(seconds=2 * MAX_TIME))
+
+ start = datetime.datetime.now()
+ model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
+ duration = datetime.datetime.now() - start
+ self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
+ self.assertLess(duration, datetime.timedelta(seconds=2 * MAX_TIME))
+
+ start = datetime.datetime.now()
+ model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
+ duration = datetime.datetime.now() - start
+ self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
+ self.assertLess(duration, datetime.timedelta(seconds=2 * MAX_TIME))
+
+ start = datetime.datetime.now()
+ model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
+ duration = datetime.datetime.now() - start
+ self.assertGreater(duration, datetime.timedelta(seconds=2 * MAX_TIME))
diff --git a/tests/models/codegen/test_tokenization_codegen.py b/tests/models/codegen/test_tokenization_codegen.py
new file mode 100644
index 000000000000..c15c8236b8da
--- /dev/null
+++ b/tests/models/codegen/test_tokenization_codegen.py
@@ -0,0 +1,265 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+import os
+import re
+import unittest
+
+from transformers import CodeGenTokenizer, CodeGenTokenizerFast
+from transformers.models.codegen.tokenization_codegen import VOCAB_FILES_NAMES
+from transformers.testing_utils import require_tokenizers, slow
+
+from ...test_tokenization_common import TokenizerTesterMixin
+
+
+@require_tokenizers
+class CodeGenTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+
+ tokenizer_class = CodeGenTokenizer
+ rust_tokenizer_class = CodeGenTokenizerFast
+ test_rust_tokenizer = True
+ from_pretrained_kwargs = {"add_prefix_space": True}
+ test_seq2seq = False
+
+ def setUp(self):
+ super().setUp()
+
+ # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
+ vocab = [
+ "l",
+ "o",
+ "w",
+ "e",
+ "r",
+ "s",
+ "t",
+ "i",
+ "d",
+ "n",
+ "\u0120",
+ "\u0120l",
+ "\u0120n",
+ "\u0120lo",
+ "\u0120low",
+ "er",
+ "\u0120lowest",
+ "\u0120newer",
+ "\u0120wider",
+ "",
+ "<|endoftext|>",
+ ]
+ vocab_tokens = dict(zip(vocab, range(len(vocab))))
+ merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
+ self.special_tokens_map = {"unk_token": "<unk>"}
+
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+ self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
+ with open(self.vocab_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(vocab_tokens) + "\n")
+ with open(self.merges_file, "w", encoding="utf-8") as fp:
+ fp.write("\n".join(merges))
+
+ def get_tokenizer(self, **kwargs):
+ kwargs.update(self.special_tokens_map)
+ return CodeGenTokenizer.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_rust_tokenizer(self, **kwargs):
+ kwargs.update(self.special_tokens_map)
+ return CodeGenTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_input_output_texts(self, tokenizer):
+ input_text = "lower newer"
+ output_text = "lower newer"
+ return input_text, output_text
+
+ def test_full_tokenizer(self):
+ tokenizer = CodeGenTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
+ text = "lower newer"
+ bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
+ tokens = tokenizer.tokenize(text, add_prefix_space=True)
+ self.assertListEqual(tokens, bpe_tokens)
+
+ input_tokens = tokens + [tokenizer.unk_token]
+ input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
+ self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+
+ def test_rust_and_python_full_tokenizers(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)
+
+ sequence = "lower newer"
+
+ # Testing tokenization
+ tokens = tokenizer.tokenize(sequence, add_prefix_space=True)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ # Testing conversion to ids without special tokens
+ ids = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ # Testing conversion to ids with special tokens
+ rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)
+ ids = tokenizer.encode(sequence, add_prefix_space=True)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
+
+ # Testing the unknown token
+ input_tokens = tokens + [rust_tokenizer.unk_token]
+ input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
+ self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+
+ def test_pretokenized_inputs(self, *args, **kwargs):
+ # It's very difficult to mix/test pretokenization with byte-level BPE and get both CodeGen and Roberta
+ # to work at the same time (mostly an issue of adding a space before the string).
+ pass
+
+ def test_padding(self, max_length=15):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ # Simple input
+ s = "This is a simple input"
+ s2 = ["This is a simple input 1", "This is a simple input 2"]
+ p = ("This is a simple input", "This is a pair")
+ p2 = [
+ ("This is a simple input 1", "This is a simple input 2"),
+ ("This is a simple pair 1", "This is a simple pair 2"),
+ ]
+
+ # Simple input tests
+ self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")
+
+ # Simple input
+ self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")
+
+ # Simple input
+ self.assertRaises(
+ ValueError,
+ tokenizer_r.batch_encode_plus,
+ s2,
+ max_length=max_length,
+ padding="max_length",
+ )
+
+ # Pair input
+ self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")
+
+ # Pair input
+ self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")
+
+ # Pair input
+ self.assertRaises(
+ ValueError,
+ tokenizer_r.batch_encode_plus,
+ p2,
+ max_length=max_length,
+ padding="max_length",
+ )
+
+ def test_padding_if_pad_token_set_slow(self):
+ tokenizer = CodeGenTokenizer.from_pretrained(self.tmpdirname, pad_token="<pad>")
+
+ # Simple input
+ s = "This is a simple input"
+ s2 = ["This is a simple input looooooooong", "This is a simple input"]
+ p = ("This is a simple input", "This is a pair")
+ p2 = [
+ ("This is a simple input loooooong", "This is a simple input"),
+ ("This is a simple pair loooooong", "This is a simple pair"),
+ ]
+
+ pad_token_id = tokenizer.pad_token_id
+
+ out_s = tokenizer(s, padding="max_length", max_length=30, return_tensors="np")
+ out_s2 = tokenizer(s2, padding=True, truncation=True, return_tensors="np")
+ out_p = tokenizer(*p, padding="max_length", max_length=60, return_tensors="np")
+ out_p2 = tokenizer(p2, padding=True, truncation=True, return_tensors="np")
+
+ # s
+ # test single string max_length padding
+ self.assertEqual(out_s["input_ids"].shape[-1], 30)
+ self.assertTrue(pad_token_id in out_s["input_ids"])
+ self.assertTrue(0 in out_s["attention_mask"])
+
+ # s2
+ # test automatic padding
+ self.assertEqual(out_s2["input_ids"].shape[-1], 33)
+ # long slice doesn't have padding
+ self.assertFalse(pad_token_id in out_s2["input_ids"][0])
+ self.assertFalse(0 in out_s2["attention_mask"][0])
+ # short slice does have padding
+ self.assertTrue(pad_token_id in out_s2["input_ids"][1])
+ self.assertTrue(0 in out_s2["attention_mask"][1])
+
+ # p
+ # test single pair max_length padding
+ self.assertEqual(out_p["input_ids"].shape[-1], 60)
+ self.assertTrue(pad_token_id in out_p["input_ids"])
+ self.assertTrue(0 in out_p["attention_mask"])
+
+ # p2
+ # test automatic padding pair
+ self.assertEqual(out_p2["input_ids"].shape[-1], 52)
+ # long slice pair doesn't have padding
+ self.assertFalse(pad_token_id in out_p2["input_ids"][0])
+ self.assertFalse(0 in out_p2["attention_mask"][0])
+ # short slice pair does have padding
+ self.assertTrue(pad_token_id in out_p2["input_ids"][1])
+ self.assertTrue(0 in out_p2["attention_mask"][1])
+
+ def test_add_bos_token_slow(self):
+ bos_token = "$$$"
+ tokenizer = CodeGenTokenizer.from_pretrained(self.tmpdirname, bos_token=bos_token, add_bos_token=True)
+
+ s = "This is a simple input"
+ s2 = ["This is a simple input 1", "This is a simple input 2"]
+
+ bos_token_id = tokenizer.bos_token_id
+
+ out_s = tokenizer(s)
+ out_s2 = tokenizer(s2)
+
+ self.assertEqual(out_s.input_ids[0], bos_token_id)
+ self.assertTrue(all(o[0] == bos_token_id for o in out_s2.input_ids))
+
+ decode_s = tokenizer.decode(out_s.input_ids)
+ decode_s2 = tokenizer.batch_decode(out_s2.input_ids)
+
+ self.assertEqual(decode_s.split()[0], bos_token)
+ self.assertTrue(all(d.split()[0] == bos_token for d in decode_s2))
+
+ @slow
+ def test_truncation(self):
+ tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
+
+ text = "\nif len_a > len_b:\n result = a\nelse:\n result = b\n\n\n\n#"
+ expected_truncated_text = "\nif len_a > len_b: result = a\nelse: result = b"
+
+ input_ids = tokenizer.encode(text)
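+ # decoding is cut off before the first match of any of these patterns (comment marker, EOS token, docstring quotes, or a blank-line run)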
+ truncation_pattern = ["^#", re.escape("<|endoftext|>"), "^'''", '^"""', "\n\n\n"]
+ decoded_text = tokenizer.decode(input_ids, truncate_before_pattern=truncation_pattern)
+ self.assertEqual(decoded_text, expected_truncated_text)
+
+ # tokenizer has no padding token
+ def test_padding_different_model_input_name(self):
+ pass
diff --git a/tests/models/convnext/test_modeling_convnext.py b/tests/models/convnext/test_modeling_convnext.py
index f12a21bfe64c..46ef3ce71709 100644
--- a/tests/models/convnext/test_modeling_convnext.py
+++ b/tests/models/convnext/test_modeling_convnext.py
@@ -158,6 +158,10 @@ def test_config(self):
def create_and_test_config_common_properties(self):
return
+ @unittest.skip(reason="ConvNext does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
@unittest.skip(reason="ConvNext does not use inputs_embeds")
def test_inputs_embeds(self):
pass
diff --git a/tests/models/convnext/test_modeling_tf_convnext.py b/tests/models/convnext/test_modeling_tf_convnext.py
index 7b86a99fd435..bc84cd0a4000 100644
--- a/tests/models/convnext/test_modeling_tf_convnext.py
+++ b/tests/models/convnext/test_modeling_tf_convnext.py
@@ -174,6 +174,13 @@ def test_model(self):
def test_attention_outputs(self):
pass
+ @unittest.skipIf(
+ not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
+ reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
+ )
+ def test_dataset_conversion(self):
+ super().test_dataset_conversion()
+
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
@@ -219,7 +226,10 @@ def recursive_check(tuple_object, dict_object):
else:
self.assertTrue(
all(tf.equal(tuple_object, dict_object)),
- msg=f"Tuple and dict output are not equal. Difference: {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}",
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}"
+ ),
)
recursive_check(tuple_output, dict_output)
diff --git a/tests/models/ctrl/test_modeling_ctrl.py b/tests/models/ctrl/test_modeling_ctrl.py
index 0256a5718b5e..ad6652f882d5 100644
--- a/tests/models/ctrl/test_modeling_ctrl.py
+++ b/tests/models/ctrl/test_modeling_ctrl.py
@@ -13,6 +13,7 @@
# limitations under the License.
+import gc
import unittest
from transformers import CTRLConfig, is_torch_available
@@ -181,6 +182,12 @@ def setUp(self):
self.model_tester = CTRLModelTester(self)
self.config_tester = ConfigTester(self, config_class=CTRLConfig, n_embd=37)
+ def tearDown(self):
+ super().tearDown()
+ # clean up as much GPU memory occupied by PyTorch as possible
+ gc.collect()
+ torch.cuda.empty_cache()
+
def test_config(self):
self.config_tester.run_common_tests()
@@ -201,6 +208,12 @@ def test_model_from_pretrained(self):
@require_torch
class CTRLModelLanguageGenerationTest(unittest.TestCase):
+ def tearDown(self):
+ super().tearDown()
+ # clean up as much GPU memory occupied by PyTorch as possible
+ gc.collect()
+ torch.cuda.empty_cache()
+
@slow
def test_lm_generate_ctrl(self):
model = CTRLLMHeadModel.from_pretrained("ctrl")
diff --git a/tests/models/cvt/__init__.py b/tests/models/cvt/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/cvt/test_modeling_cvt.py b/tests/models/cvt/test_modeling_cvt.py
new file mode 100644
index 000000000000..b88f22d982be
--- /dev/null
+++ b/tests/models/cvt/test_modeling_cvt.py
@@ -0,0 +1,282 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch CvT model. """
+
+
+import inspect
+import unittest
+from math import floor
+
+from transformers import CvtConfig
+from transformers.file_utils import cached_property, is_torch_available, is_vision_available
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import CvtForImageClassification, CvtModel
+ from transformers.models.cvt.modeling_cvt import CVT_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import AutoFeatureExtractor
+
+
+class CvtConfigTester(ConfigTester):
+ def create_and_test_config_common_properties(self):
+ config = self.config_class(**self.inputs_dict)
+ self.parent.assertTrue(hasattr(config, "embed_dim"))
+ self.parent.assertTrue(hasattr(config, "num_heads"))
+
+
+class CvtModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=64,
+ num_channels=3,
+ embed_dim=[16, 48, 96],
+ num_heads=[1, 3, 6],
+ depth=[1, 2, 10],
+ patch_sizes=[7, 3, 3],
+ patch_stride=[4, 2, 2],
+ patch_padding=[2, 1, 1],
+ stride_kv=[2, 2, 2],
+ cls_token=[False, False, True],
+ attention_drop_rate=[0.0, 0.0, 0.0],
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ is_training=True,
+ use_labels=True,
+ num_labels=2,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_sizes = patch_sizes
+ self.patch_stride = patch_stride
+ self.patch_padding = patch_padding
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.num_labels = num_labels
+ self.num_channels = num_channels
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.stride_kv = stride_kv
+ self.depth = depth
+ self.cls_token = cls_token
+ self.attention_drop_rate = attention_drop_rate
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.num_labels)
+
+ config = self.get_config()
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return CvtConfig(
+ image_size=self.image_size,
+ num_labels=self.num_labels,
+ num_channels=self.num_channels,
+ embed_dim=self.embed_dim,
+ num_heads=self.num_heads,
+ patch_sizes=self.patch_sizes,
+ patch_padding=self.patch_padding,
+ patch_stride=self.patch_stride,
+ stride_kv=self.stride_kv,
+ depth=self.depth,
+ cls_token=self.cls_token,
+ attention_drop_rate=self.attention_drop_rate,
+ initializer_range=self.initializer_range,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = CvtModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ image_size = (self.image_size, self.image_size)
+ height, width = image_size[0], image_size[1]
+ for i in range(len(self.depth)):
+ height = floor(((height + 2 * self.patch_padding[i] - self.patch_sizes[i]) / self.patch_stride[i]) + 1)
+ width = floor(((width + 2 * self.patch_padding[i] - self.patch_sizes[i]) / self.patch_stride[i]) + 1)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.embed_dim[-1], height, width))
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.num_labels
+ model = CvtForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class CvtModelTest(ModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as Cvt does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (CvtModel, CvtForImageClassification) if is_torch_available() else ()
+
+ test_pruning = False
+ test_torchscript = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ has_attentions = False
+
+ def setUp(self):
+ self.model_tester = CvtModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=CvtConfig, has_text_modality=False, hidden_size=37)
+
+ def test_config(self):
+ self.create_and_test_config_common_properties()
+ self.config_tester.create_and_test_config_to_json_string()
+ self.config_tester.create_and_test_config_to_json_file()
+ self.config_tester.create_and_test_config_from_and_save_pretrained()
+ self.config_tester.create_and_test_config_with_num_labels()
+ self.config_tester.check_config_can_be_init_without_params()
+ self.config_tester.check_config_arguments_init()
+
+ def create_and_test_config_common_properties(self):
+ return
+
+ @unittest.skip(reason="Cvt does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
+ @unittest.skip(reason="Cvt does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="Cvt does not support input and output embeddings")
+ def test_model_common_attributes(self):
+ pass
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.hidden_states
+
+ expected_num_layers = len(self.model_tester.depth)
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ # verify the first hidden states (first block)
+ self.assertListEqual(
+ list(hidden_states[0].shape[-3:]),
+ [
+ self.model_tester.embed_dim[0],
+ self.model_tester.image_size // 4,
+ self.model_tester.image_size // 4,
+ ],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in CVT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = CvtModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_torch
+@require_vision
+class CvtModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return AutoFeatureExtractor.from_pretrained(CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0])
+
+ @slow
+ def test_inference_image_classification_head(self):
+ model = CvtForImageClassification.from_pretrained(CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0]).to(torch_device)
+
+ feature_extractor = self.default_feature_extractor
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ expected_shape = torch.Size((1, 1000))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor([0.9285, 0.9015, -0.3150]).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
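For reference, the per-stage spatial size asserted in `create_and_check_model` above follows the usual convolution output formula floor((size + 2 * padding - kernel) / stride) + 1; with the tester defaults (image_size=64, patch_sizes=[7, 3, 3], patch_stride=[4, 2, 2], patch_padding=[2, 1, 1]) it works out as follows (a small sanity-check script, not part of the test suite):

from math import floor

size = 64  # tester's default image_size
for kernel, stride, padding in zip([7, 3, 3], [4, 2, 2], [2, 1, 1]):
    size = floor((size + 2 * padding - kernel) / stride) + 1
    print(size)  # 16, then 8, then 4 -> last_hidden_state has shape (batch, 96, 4, 4)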
diff --git a/tests/models/data2vec/test_modeling_data2vec_audio.py b/tests/models/data2vec/test_modeling_data2vec_audio.py
index 87885268b261..e3fb96097d84 100644
--- a/tests/models/data2vec/test_modeling_data2vec_audio.py
+++ b/tests/models/data2vec/test_modeling_data2vec_audio.py
@@ -535,7 +535,7 @@ def _mock_init_weights(self, module):
def test_mask_feature_prob_ctc(self):
model = Data2VecAudioForCTC.from_pretrained(
- "facebook/data2vec-audio-base-960h", mask_feature_prob=0.2, mask_feature_length=2
+ "hf-internal-testing/tiny-random-data2vec-seq-class", mask_feature_prob=0.2, mask_feature_length=2
)
model.to(torch_device).train()
processor = Wav2Vec2Processor.from_pretrained(
@@ -554,7 +554,7 @@ def test_mask_feature_prob_ctc(self):
attention_mask=batch["attention_mask"].to(torch_device),
).logits
- self.assertEqual(logits.shape, (4, 299, 32))
+ self.assertEqual(logits.shape, (4, 1498, 32))
def test_mask_time_prob_ctc(self):
model = Data2VecAudioForCTC.from_pretrained(
@@ -736,7 +736,8 @@ def test_inference_ctc_batched(self):
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
- "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with thousands of spectators were trivialities not worth thinking about",
+ "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around"
+ " him with thousands of spectators were trivialities not worth thinking about",
"his instant of panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/models/data2vec/test_modeling_data2vec_vision.py b/tests/models/data2vec/test_modeling_data2vec_vision.py
index 2dc9f1e45e52..a7974e8cbd98 100644
--- a/tests/models/data2vec/test_modeling_data2vec_vision.py
+++ b/tests/models/data2vec/test_modeling_data2vec_vision.py
@@ -20,7 +20,7 @@
from transformers import Data2VecVisionConfig
from transformers.models.auto import get_values
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -37,10 +37,7 @@
Data2VecVisionForSemanticSegmentation,
Data2VecVisionModel,
)
- from transformers.models.data2vec.modeling_data2vec_vision import (
- DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST,
- to_2tuple,
- )
+ from transformers.models.data2vec.modeling_data2vec_vision import DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
@@ -94,6 +91,10 @@ def __init__(
self.out_indices = out_indices
self.num_labels = num_labels
+ # in Data2VecVision (which follows BEiT), the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
+ num_patches = (image_size // patch_size) ** 2
+ self.seq_length = num_patches + 1
+
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -131,9 +132,7 @@ def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
model.eval()
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
- image_size = to_2tuple(self.image_size)
- patch_size = to_2tuple(self.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ num_patches = (self.image_size // self.patch_size) ** 2
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
@@ -195,6 +194,13 @@ def test_inputs_embeds(self):
# Data2VecVision does not use inputs_embeds
pass
+ @require_torch_multi_gpu
+ @unittest.skip(
+ reason="Data2VecVision has some layers using `add_module` which doesn't work well with `nn.DataParallel`"
+ )
+ def test_multi_gpu_data_parallel_forward(self):
+ pass
+
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -286,108 +292,9 @@ def test_initialization(self):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
- def test_attention_outputs(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.return_dict = True
-
- # in Data2VecVision, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
- image_size = to_2tuple(self.model_tester.image_size)
- patch_size = to_2tuple(self.model_tester.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- seq_len = num_patches + 1
- encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
- encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
- chunk_length = getattr(self.model_tester, "chunk_length", None)
- if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
- encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- config.return_dict = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
- )
- out_len = len(outputs)
-
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- self.assertEqual(out_len + 1, len(outputs))
-
- self_attentions = outputs.attentions
-
- self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
- self.assertListEqual(
- list(self_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
- )
-
- def test_hidden_states_output(self):
- def check_hidden_states_output(inputs_dict, config, model_class):
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
-
- expected_num_layers = getattr(
- self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
- )
- self.assertEqual(len(hidden_states), expected_num_layers)
-
- # Data2VecVision has a different seq_length
- image_size = to_2tuple(self.model_tester.image_size)
- patch_size = to_2tuple(self.model_tester.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- seq_length = num_patches + 1
-
- self.assertListEqual(
- list(hidden_states[0].shape[-2:]),
- [seq_length, self.model_tester.hidden_size],
- )
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- inputs_dict["output_hidden_states"] = True
- check_hidden_states_output(inputs_dict, config, model_class)
-
- # check that output_hidden_states also work using config
- del inputs_dict["output_hidden_states"]
- config.output_hidden_states = True
-
- check_hidden_states_output(inputs_dict, config, model_class)
+ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
+ # We override with a slightly higher tol value, as semseg models tend to diverge a bit more
+ super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
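Dropping `to_2tuple` in the hunk above is safe because the test images and patches are square, so squaring the patch-grid side gives the same count as multiplying height by width; a quick check with illustrative values (not necessarily the tester's defaults):

image_size, patch_size = 30, 2  # illustrative values
grid = image_size // patch_size
assert grid * grid == (image_size // patch_size) ** 2 == 225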
diff --git a/tests/models/data2vec/test_modeling_tf_data2vec_vision.py b/tests/models/data2vec/test_modeling_tf_data2vec_vision.py
index 17b02d037c19..eb085af0d82b 100644
--- a/tests/models/data2vec/test_modeling_tf_data2vec_vision.py
+++ b/tests/models/data2vec/test_modeling_tf_data2vec_vision.py
@@ -31,7 +31,11 @@
if is_tf_available():
import tensorflow as tf
- from transformers import TFData2VecVisionForImageClassification, TFData2VecVisionModel
+ from transformers import (
+ TFData2VecVisionForImageClassification,
+ TFData2VecVisionForSemanticSegmentation,
+ TFData2VecVisionModel,
+ )
from transformers.models.data2vec.modeling_tf_data2vec_vision import (
TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST,
)
@@ -142,6 +146,18 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels
result = model(pixel_values, labels=labels, training=False)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+ def create_and_check_for_image_segmentation(self, config, pixel_values, labels, pixel_labels):
+ config.num_labels = self.num_labels
+ model = TFData2VecVisionForSemanticSegmentation(config)
+ result = model(pixel_values, training=False)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
+ )
+ result = model(pixel_values, labels=pixel_labels)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
+ )
+
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels, pixel_labels = config_and_inputs
@@ -162,7 +178,11 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
attention_mask and seq_length.
"""
- all_model_classes = (TFData2VecVisionModel, TFData2VecVisionForImageClassification) if is_tf_available() else ()
+ all_model_classes = (
+ (TFData2VecVisionModel, TFData2VecVisionForImageClassification, TFData2VecVisionForSemanticSegmentation)
+ if is_tf_available()
+ else ()
+ )
test_pruning = False
test_onnx = False
@@ -208,6 +228,14 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
+ def test_for_image_segmentation(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs)
+
+ @unittest.skip("Test was written for TF 1.x and isn't really relevant here")
+ def test_compile_tf_model(self):
+ pass
+
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
@@ -354,6 +382,10 @@ def test_keras_fit(self):
val_loss2 = history2.history["val_loss"][0]
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
+ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
+ # We override with a slightly higher tol value, as semseg models tend to diverge a bit more
+ super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
+
# Overriding this method since the base method won't be compatible with Data2VecVision.
def test_loss_computation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/deberta/test_modeling_deberta.py b/tests/models/deberta/test_modeling_deberta.py
index 8d2e2dd020aa..940a82db4398 100644
--- a/tests/models/deberta/test_modeling_deberta.py
+++ b/tests/models/deberta/test_modeling_deberta.py
@@ -130,6 +130,11 @@ def get_config(self):
pos_att_type=self.pos_att_type,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def check_loss_output(self, result):
self.parent.assertListEqual(list(result.loss.size()), [])
@@ -222,6 +227,7 @@ class DebertaModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
+ fx_compatible = True
test_torchscript = False
test_pruning = False
test_head_masking = False
diff --git a/tests/models/deberta/test_tokenization_deberta.py b/tests/models/deberta/test_tokenization_deberta.py
index ca6574bc31cb..4aa53e13ff8d 100644
--- a/tests/models/deberta/test_tokenization_deberta.py
+++ b/tests/models/deberta/test_tokenization_deberta.py
@@ -126,7 +126,9 @@ def test_tokenizer_integration(self):
sequences = [
"ALBERT: A Lite BERT for Self-supervised Learning of Language Representations",
"ALBERT incorporates two parameter reduction techniques",
- "The first one is a factorized embedding parameterization. By decomposing the large vocabulary embedding matrix into two small matrices, we separate the size of the hidden layers from the size of vocabulary embedding.",
+ "The first one is a factorized embedding parameterization. By decomposing the large vocabulary"
+ " embedding matrix into two small matrices, we separate the size of the hidden layers from the size of"
+ " vocabulary embedding.",
]
encoding = tokenizer(sequences, padding=True)
@@ -155,7 +157,9 @@ def test_tokenizer_integration(self):
expected_decoded_sequence = [
"ALBERT: A Lite BERT for Self-supervised Learning of Language Representations",
"ALBERT incorporates two parameter reduction techniques",
- "The first one is a factorized embedding parameterization. By decomposing the large vocabulary embedding matrix into two small matrices, we separate the size of the hidden layers from the size of vocabulary embedding.",
+ "The first one is a factorized embedding parameterization. By decomposing the large vocabulary"
+ " embedding matrix into two small matrices, we separate the size of the hidden layers from the size of"
+ " vocabulary embedding.",
]
self.assertDictEqual(encoding.data, expected_encoding)
diff --git a/tests/models/deberta_v2/test_modeling_deberta_v2.py b/tests/models/deberta_v2/test_modeling_deberta_v2.py
index 17cdf3ea8f93..93436b901bb1 100644
--- a/tests/models/deberta_v2/test_modeling_deberta_v2.py
+++ b/tests/models/deberta_v2/test_modeling_deberta_v2.py
@@ -26,6 +26,7 @@
from transformers import (
DebertaV2ForMaskedLM,
+ DebertaV2ForMultipleChoice,
DebertaV2ForQuestionAnswering,
DebertaV2ForSequenceClassification,
DebertaV2ForTokenClassification,
@@ -192,6 +193,23 @@ def create_and_check_deberta_for_question_answering(
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
+ def create_and_check_deberta_for_multiple_choice(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ model = DebertaV2ForMultipleChoice(config=config)
+ model.to(torch_device)
+ model.eval()
+ multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ result = model(
+ multiple_choice_inputs_ids,
+ attention_mask=multiple_choice_input_mask,
+ token_type_ids=multiple_choice_token_type_ids,
+ labels=choice_labels,
+ )
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
+
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -217,11 +235,13 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase):
DebertaV2ForSequenceClassification,
DebertaV2ForTokenClassification,
DebertaV2ForQuestionAnswering,
+ DebertaV2ForMultipleChoice,
)
if is_torch_available()
else ()
)
+ fx_compatible = True
test_torchscript = False
test_pruning = False
test_head_masking = False
@@ -254,6 +274,10 @@ def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deberta_for_token_classification(*config_and_inputs)
+ def test_for_multiple_choice(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_deberta_for_multiple_choice(*config_and_inputs)
+
@slow
def test_model_from_pretrained(self):
for model_name in DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
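The new multiple-choice test repeats each (batch, seq_len) tensor along a fresh choices axis via `unsqueeze(1).expand(...).contiguous()`; the same reshaping in isolation (a standalone sketch, independent of the DeBERTa-v2 model):

import torch

batch, num_choices, seq_len = 2, 4, 7
input_ids = torch.arange(batch * seq_len).view(batch, seq_len)
expanded = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
assert expanded.shape == (batch, num_choices, seq_len)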
diff --git a/tests/models/decision_transformer/test_modeling_decision_transformer.py b/tests/models/decision_transformer/test_modeling_decision_transformer.py
index 9124c64fa1d4..3ac89cf9bfc1 100644
--- a/tests/models/decision_transformer/test_modeling_decision_transformer.py
+++ b/tests/models/decision_transformer/test_modeling_decision_transformer.py
@@ -206,7 +206,9 @@ def test_autoregressive_prediction(self):
torch.manual_seed(0)
state = torch.randn(1, 1, config.state_dim).to(device=torch_device, dtype=torch.float32) # env.reset()
- expected_outputs = torch.tensor([[0.2384, -0.2955, 0.8741], [0.6765, -0.0793, -0.1298]], device=torch_device)
+ expected_outputs = torch.tensor(
+ [[0.242793, -0.28693074, 0.8742613], [0.67815274, -0.08101085, -0.12952147]], device=torch_device
+ )
returns_to_go = torch.tensor(TARGET_RETURN, device=torch_device, dtype=torch.float32).reshape(1, 1, 1)
states = state
diff --git a/tests/models/deit/test_modeling_deit.py b/tests/models/deit/test_modeling_deit.py
index 4559fa0c7127..27f92c2d976a 100644
--- a/tests/models/deit/test_modeling_deit.py
+++ b/tests/models/deit/test_modeling_deit.py
@@ -131,6 +131,25 @@ def create_and_check_model(self, config, pixel_values, labels):
result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
+ model = DeiTForMaskedImageModeling(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
+ )
+
+ # test greyscale images
+ config.num_channels = 1
+ model = DeiTForMaskedImageModeling(config)
+ model.to(torch_device)
+ model.eval()
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
+
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = DeiTForImageClassification(config)
@@ -139,6 +158,16 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+ # test greyscale images
+ config.num_channels = 1
+ model = DeiTForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -208,6 +237,10 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
+ def test_for_masked_image_modeling(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
+
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
diff --git a/tests/models/deit/test_modeling_tf_deit.py b/tests/models/deit/test_modeling_tf_deit.py
new file mode 100644
index 000000000000..2a9638eda42e
--- /dev/null
+++ b/tests/models/deit/test_modeling_tf_deit.py
@@ -0,0 +1,282 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the TensorFlow DeiT model. """
+
+
+import inspect
+import unittest
+
+import numpy as np
+
+from transformers import DeiTConfig
+from transformers.testing_utils import require_tf, require_vision, slow
+from transformers.utils import cached_property, is_tf_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_tf_available():
+ import tensorflow as tf
+
+ from transformers import (
+ TFDeiTForImageClassification,
+ TFDeiTForImageClassificationWithTeacher,
+ TFDeiTForMaskedImageModeling,
+ TFDeiTModel,
+ )
+ from transformers.models.deit.modeling_tf_deit import TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import DeiTFeatureExtractor
+
+
+class TFDeiTModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=30,
+ patch_size=2,
+ num_channels=3,
+ is_training=True,
+ use_labels=True,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ type_sequence_label_size=10,
+ initializer_range=0.02,
+ num_labels=3,
+ scope=None,
+ encoder_stride=2,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.scope = scope
+ self.encoder_stride = encoder_stride
+
+ # in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
+ num_patches = (image_size // patch_size) ** 2
+ self.seq_length = num_patches + 2
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return DeiTConfig(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ is_decoder=False,
+ initializer_range=self.initializer_range,
+ encoder_stride=self.encoder_stride,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = TFDeiTModel(config=config)
+ result = model(pixel_values)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
+ model = TFDeiTForMaskedImageModeling(config=config)
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
+ )
+
+ # test greyscale images
+ config.num_channels = 1
+ model = TFDeiTForMaskedImageModeling(config)
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.type_sequence_label_size
+ model = TFDeiTForImageClassification(config)
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
+ # test greyscale images
+ config.num_channels = 1
+ model = TFDeiTForImageClassification(config)
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_tf
+class TFDeiTModelTest(TFModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_tf_common.py, as DeiT does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (
+ (
+ TFDeiTModel,
+ TFDeiTForImageClassification,
+ TFDeiTForImageClassificationWithTeacher,
+ TFDeiTForMaskedImageModeling,
+ )
+ if is_tf_available()
+ else ()
+ )
+
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ test_onnx = False
+
+ def setUp(self):
+ self.model_tester = TFDeiTModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=DeiTConfig, has_text_modality=False, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(reason="DeiT does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ def test_model_common_attributes(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, tf.keras.layers.Dense))
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.call)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_masked_image_modeling(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ # special case for DeiTForImageClassificationWithTeacher model
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+ if return_labels:
+ if model_class.__name__ == "TFDeiTForImageClassificationWithTeacher":
+ del inputs_dict["labels"]
+
+ return inputs_dict
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = TFDeiTModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_tf
+@require_vision
+class DeiTModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return (
+ DeiTFeatureExtractor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+ if is_vision_available()
+ else None
+ )
+
+ @slow
+ def test_inference_image_classification_head(self):
+ model = TFDeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+ feature_extractor = self.default_feature_extractor
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="tf")
+
+ # forward pass
+ outputs = model(**inputs)
+
+ # verify the logits
+ expected_shape = tf.TensorShape((1, 1000))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = tf.constant([-1.0266, 0.1912, -1.2861])
+
+ self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
diff --git a/tests/models/dpr/test_tokenization_dpr.py b/tests/models/dpr/test_tokenization_dpr.py
index 2870e0bcf352..8ad2fea09c8b 100644
--- a/tests/models/dpr/test_tokenization_dpr.py
+++ b/tests/models/dpr/test_tokenization_dpr.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from transformers import (
DPRContextEncoderTokenizer,
DPRContextEncoderTokenizerFast,
diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
index b356b3ee0ba1..6980ed6cb26e 100644
--- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
+++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
@@ -351,6 +351,40 @@ def check_encoder_decoder_model_labels(
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
)
+ def _check_output_with_attentions(
+ self, outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+ ):
+ encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
+ self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
+
+ self.assertEqual(
+ encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1])
+ )
+
+ decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
+ num_decoder_layers = (
+ decoder_config.num_decoder_layers
+ if hasattr(decoder_config, "num_decoder_layers")
+ else decoder_config.num_hidden_layers
+ )
+ self.assertEqual(len(decoder_attentions), num_decoder_layers)
+
+ self.assertEqual(
+ decoder_attentions[0].shape[-3:],
+ (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
+ )
+
+ cross_attentions = outputs_encoder_decoder["cross_attentions"]
+ self.assertEqual(len(cross_attentions), num_decoder_layers)
+
+ cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
+ 1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
+ )
+ self.assertEqual(
+ cross_attentions[0].shape[-3:],
+ (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
+ )
+
def check_encoder_decoder_model_output_attentions(
self,
config,
@@ -376,36 +410,58 @@ def check_encoder_decoder_model_output_attentions(
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
)
-
- encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
- self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
-
- self.assertEqual(
- encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1])
+ self._check_output_with_attentions(
+ outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
)
- decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
- num_decoder_layers = (
- decoder_config.num_decoder_layers
- if hasattr(decoder_config, "num_decoder_layers")
- else decoder_config.num_hidden_layers
- )
- self.assertEqual(len(decoder_attentions), num_decoder_layers)
+ def check_encoder_decoder_model_output_attentions_from_config(
+ self,
+ config,
+ input_ids,
+ attention_mask,
+ encoder_hidden_states,
+ decoder_config,
+ decoder_input_ids,
+ decoder_attention_mask,
+ labels,
+ **kwargs
+ ):
+ # Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
+ # config. Contrary to most models, changing the model's config won't work -- the defaults are loaded
+ # from the inner models' configurations.
- self.assertEqual(
- decoder_attentions[0].shape[-3:],
- (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
+ decoder_input_ids = decoder_input_ids[:, :-1]
+ decoder_attention_mask = decoder_attention_mask[:, :-1]
+ encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+ enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+ enc_dec_model.config.output_attentions = True # model config -> won't work
+ enc_dec_model.to(torch_device)
+ outputs_encoder_decoder = enc_dec_model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ )
+ self.assertTrue(
+ all(
+ key not in outputs_encoder_decoder
+ for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]
+ )
)
- cross_attentions = outputs_encoder_decoder["cross_attentions"]
- self.assertEqual(len(cross_attentions), num_decoder_layers)
-
- cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
- 1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
+ config.output_attentions = True # inner model config -> will work
+ decoder_config.output_attentions = True
+ encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+ enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+ enc_dec_model.to(torch_device)
+ outputs_encoder_decoder = enc_dec_model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
)
- self.assertEqual(
- cross_attentions[0].shape[-3:],
- (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
+ self._check_output_with_attentions(
+ outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
)
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
@@ -543,6 +599,10 @@ def test_encoder_decoder_model_output_attentions(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
+ def test_encoder_decoder_model_output_attentions_from_config(self):
+ input_ids_dict = self.prepare_config_and_inputs()
+ self.check_encoder_decoder_model_output_attentions_from_config(**input_ids_dict)
+
def test_encoder_decoder_model_generate(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_generate(**input_ids_dict)
diff --git a/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py
index 74eb59b4e016..d179d5f9d517 100644
--- a/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py
+++ b/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py
@@ -23,6 +23,7 @@
from transformers import is_tf_available, is_torch_available
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_torch, slow, torch_device
+from transformers.utils.generic import ModelOutput
from ...test_modeling_tf_common import ids_tensor
from ..bert.test_modeling_tf_bert import TFBertModelTester
@@ -254,31 +255,9 @@ def check_encoder_decoder_model_labels(
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
)
- def check_encoder_decoder_model_output_attentions(
- self,
- config,
- input_ids,
- attention_mask,
- encoder_hidden_states,
- decoder_config,
- decoder_input_ids,
- decoder_attention_mask,
- **kwargs
+ def _check_output_with_attentions(
+ self, outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
):
- # make the decoder inputs a different shape from the encoder inputs to harden the test
- decoder_input_ids = decoder_input_ids[:, :-1]
- decoder_attention_mask = decoder_attention_mask[:, :-1]
- encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
- enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
- outputs_encoder_decoder = enc_dec_model(
- input_ids=input_ids,
- decoder_input_ids=decoder_input_ids,
- attention_mask=attention_mask,
- decoder_attention_mask=decoder_attention_mask,
- output_attentions=True,
- kwargs=kwargs,
- )
-
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
@@ -310,6 +289,83 @@ def check_encoder_decoder_model_output_attentions(
(decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
)
+ def check_encoder_decoder_model_output_attentions(
+ self,
+ config,
+ input_ids,
+ attention_mask,
+ encoder_hidden_states,
+ decoder_config,
+ decoder_input_ids,
+ decoder_attention_mask,
+ **kwargs
+ ):
+ # make the decoder inputs a different shape from the encoder inputs to harden the test
+ decoder_input_ids = decoder_input_ids[:, :-1]
+ decoder_attention_mask = decoder_attention_mask[:, :-1]
+ encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+ enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+ outputs_encoder_decoder = enc_dec_model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ output_attentions=True,
+ kwargs=kwargs,
+ )
+ self._check_output_with_attentions(
+ outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+ )
+
+ def check_encoder_decoder_model_output_attentions_from_config(
+ self,
+ config,
+ input_ids,
+ attention_mask,
+ encoder_hidden_states,
+ decoder_config,
+ decoder_input_ids,
+ decoder_attention_mask,
+ **kwargs
+ ):
+ # Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
+ # config. Contrary to most models, changing the model's config won't work -- the defaults are loaded
+ # from the inner models' configurations.
+
+ decoder_input_ids = decoder_input_ids[:, :-1]
+ decoder_attention_mask = decoder_attention_mask[:, :-1]
+ encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+ enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+ enc_dec_model.config.output_attentions = True # model config -> won't work
+ outputs_encoder_decoder = enc_dec_model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ kwargs=kwargs,
+ )
+ self.assertTrue(
+ all(
+ key not in outputs_encoder_decoder
+ for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]
+ )
+ )
+
+ config.output_attentions = True # inner model config -> will work
+ decoder_config.output_attentions = True
+ encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+ enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+ outputs_encoder_decoder = enc_dec_model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ kwargs=kwargs,
+ )
+ self._check_output_with_attentions(
+ outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+ )
+
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@@ -326,31 +382,145 @@ def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config
)
self.assertEqual(tuple(generated_output.shape.as_list()), (input_ids.shape[0],) + (decoder_config.max_length,))
- def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict):
+ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
+ """Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.
- pt_model.to(torch_device)
- pt_model.eval()
+ Args:
+ model_class: The class of the model that is currently being tested. For example, `TFBertModel`,
+ `TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Mainly used for providing more informative
+ error messages.
+ name (`str`): The name of the output. For example, `output.hidden_states`, `output.attentions`, etc.
+ attributes (`Tuple[str]`): The names of the output's elements if the output is a tuple/list, with each
+ element being a named field in the output.
+ """
+
+ self.assertEqual(type(name), str)
+ if attributes is not None:
+ self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
+
+ # Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
+ if isinstance(tf_outputs, ModelOutput):
+ self.assertTrue(
+ isinstance(pt_outputs, ModelOutput),
+ f"{name}: `pt_outputs` should an instance of `ModelOutput` when `tf_outputs` is",
+ )
- # prepare inputs
- tf_inputs = inputs_dict
- pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
- if "labels" in pt_inputs:
- pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
+ tf_keys = [k for k, v in tf_outputs.items() if v is not None]
+ pt_keys = [k for k, v in pt_outputs.items() if v is not None]
+
+ self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")
+
+ # convert to the case of `tuple`
+ # appending each key to the current (string) `names`
+ attributes = tuple([f"{name}.{k}" for k in tf_keys])
+ self.check_pt_tf_outputs(
+ tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
+ )
+
+ # Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
+ elif type(tf_outputs) in [tuple, list]:
+ self.assertEqual(type(tf_outputs), type(pt_outputs), f"{name}: Output types differ between TF and PyTorch")
+ self.assertEqual(len(tf_outputs), len(pt_outputs), f"{name}: Output lengths differ between TF and PyTorch")
+
+ if attributes is not None:
+ # case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
+ self.assertEqual(
+ len(attributes),
+ len(tf_outputs),
+ f"{name}: The tuple `names` should have the same length as `tf_outputs`",
+ )
+ else:
+ # case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `names`
+ attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
+
+ for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
+ self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)
+
+ elif isinstance(tf_outputs, tf.Tensor):
+ self.assertTrue(
+ isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should be a tensor when `tf_outputs` is"
+ )
+
+ tf_outputs = tf_outputs.numpy()
+ pt_outputs = pt_outputs.detach().to("cpu").numpy()
+
+ self.assertEqual(
+ tf_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between TF and PyTorch"
+ )
+
+ # Deal with NumPy scalars so that replacing NaN values with 0 below works.
+ if np.isscalar(tf_outputs):
+ tf_outputs = np.array([tf_outputs])
+ pt_outputs = np.array([pt_outputs])
+
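+ # Zero out positions that are NaN on either side in both arrays, so they do not break the max-diff check.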
+ tf_nans = np.isnan(tf_outputs)
+ pt_nans = np.isnan(pt_outputs)
+
+ pt_outputs[tf_nans] = 0
+ tf_outputs[tf_nans] = 0
+ pt_outputs[pt_nans] = 0
+ tf_outputs[pt_nans] = 0
+
+ max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
+ self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
+ else:
+ raise ValueError(
+ "`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got"
+ f" {type(tf_outputs)} instead."
+ )
+
+ def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
+
+ pt_inputs_dict = {}
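+ # Booleans are passed through unchanged; known float inputs (`input_values`, `pixel_values`, `input_features`)
+ # and other floating-point tensors are cast to float32, while everything else (e.g. token ids) is cast to long.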
+ for name, key in tf_inputs_dict.items():
+ if type(key) == bool:
+ pt_inputs_dict[name] = key
+ elif name == "input_values":
+ pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
+ elif name == "pixel_values":
+ pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
+ elif name == "input_features":
+ pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
+ # other general float inputs
+ elif tf_inputs_dict[name].dtype.is_floating:
+ pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
+ else:
+ pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
+
+ return pt_inputs_dict
+
+ def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
+
+ pt_inputs_dict = self.prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
# send pytorch inputs to the correct device
- pt_inputs = {k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()}
+ pt_inputs_dict = {
+ k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
+ }
+
+ # send pytorch model to the correct device
+ pt_model.to(torch_device)
+
+ # Check that predictions on the first output (logits/hidden-states) are close enough, given low-level computational differences
+ pt_model.eval()
with torch.no_grad():
- pt_outputs = pt_model(**pt_inputs).to_tuple()
+ pt_outputs = pt_model(**pt_inputs_dict)
+ tf_outputs = tf_model(tf_inputs_dict)
- tf_outputs = tf_model(**inputs_dict)
- if "loss" in tf_outputs:
- tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss)
- tf_outputs = tf_outputs.to_tuple()
- self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
+ # The loss returned by TF models is usually a tensor rather than a scalar
+ # (see `hf_compute_loss`: it uses `tf.keras.losses.Reduction.NONE`).
+ # Reduce it to a scalar here to match the loss of PyTorch models.
+ tf_loss = getattr(tf_outputs, "loss", None)
+ if tf_loss is not None:
+ tf_outputs.loss = tf.math.reduce_mean(tf_loss)
- for tf_output, pt_output in zip(tf_outputs, pt_outputs):
- self.assert_almost_equals(tf_output.numpy(), pt_output.detach().to("cpu").numpy(), 1e-3)
+ self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model))
+
+ def check_pt_tf_equivalence(self, tf_model, pt_model, tf_inputs_dict):
+ """Wrap `check_pt_tf_models` to further check PT -> TF again"""
+
+ self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
# PT -> TF
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
@@ -363,18 +533,16 @@ def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict):
# This is only for copying some specific attributes of this particular model.
tf_model_loaded.config = pt_model.config
- tf_outputs_loaded = tf_model_loaded(**inputs_dict)
- if "loss" in tf_outputs_loaded:
- tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss)
- tf_outputs_loaded = tf_outputs_loaded.to_tuple()
- self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
+ self.check_pt_tf_models(tf_model_loaded, pt_model, tf_inputs_dict)
- for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
- self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.detach().to("cpu").numpy(), 1e-3)
-
- def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):
+ def check_pt_to_tf_equivalence(self, config, decoder_config, tf_inputs_dict):
+ """EncoderDecoderModel requires special way to cross load (PT -> TF)"""
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
+ # Output all for aggressive testing
+ encoder_decoder_config.output_hidden_states = True
+ # All models tested in this file have attentions
+ encoder_decoder_config.output_attentions = True
pt_model = EncoderDecoderModel(encoder_decoder_config)
@@ -388,11 +556,16 @@ def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):
# This is only for copying some specific attributes of this particular model.
tf_model.config = pt_model.config
- self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
+ self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
- def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
+ def check_tf_to_pt_equivalence(self, config, decoder_config, tf_inputs_dict):
+ """EncoderDecoderModel requires special way to cross load (TF -> PT)"""
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
+ # Output all for aggressive testing
+ encoder_decoder_config.output_hidden_states = True
+ # TODO: A generalizable way to determine this attribute
+ encoder_decoder_config.output_attentions = True
# Using `_tf_model`, the test will fail, because the weights of `_tf_model` get extended before saving
# the encoder/decoder models.
@@ -401,7 +574,7 @@ def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
# (the change in `src/transformers/modeling_tf_utils.py`)
_tf_model = TFEncoderDecoderModel(encoder_decoder_config)
# Make sure model is built
- _tf_model(**inputs_dict)
+ _tf_model(**tf_inputs_dict)
# Using `tf_model` to pass the test.
encoder = _tf_model.encoder.__class__(encoder_decoder_config.encoder)
@@ -410,6 +583,7 @@ def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
encoder(encoder.dummy_inputs)
decoder(decoder.dummy_inputs)
tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
+ tf_model.config = encoder_decoder_config
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
@@ -421,7 +595,7 @@ def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
# This is only for copying some specific attributes of this particular model.
pt_model.config = tf_model.config
- self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
+ self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
def test_encoder_decoder_model(self):
input_ids_dict = self.prepare_config_and_inputs()
@@ -451,6 +625,10 @@ def test_encoder_decoder_model_output_attentions(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
+ def test_encoder_decoder_model_output_attentions_from_config(self):
+ input_ids_dict = self.prepare_config_and_inputs()
+ self.check_encoder_decoder_model_output_attentions_from_config(**input_ids_dict)
+
def test_encoder_decoder_model_generate(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_generate(**input_ids_dict)
@@ -460,7 +638,7 @@ def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
self.assertLessEqual(diff, tol, f"Difference between torch and tf is {diff} (>= {tol}).")
@is_pt_tf_cross_test
- def test_pt_tf_equivalence(self):
+ def test_pt_tf_model_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
labels = config_inputs_dict.pop("decoder_token_labels")
@@ -480,48 +658,58 @@ def test_pt_tf_equivalence(self):
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")
- inputs_dict = config_inputs_dict
- # `encoder_hidden_states` is not used in model call/forward
- del inputs_dict["encoder_hidden_states"]
-
- inputs_dict_with_labels = copy.copy(inputs_dict)
- inputs_dict_with_labels["labels"] = labels
+ # Output all for aggressive testing
+ config.output_hidden_states = True
+ decoder_config.output_hidden_states = True
+ # All models tested in this file have attentions
+ config.output_attentions = True
+ decoder_config.output_attentions = True
- # Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
- batch_size = inputs_dict["decoder_attention_mask"].shape[0]
- inputs_dict["decoder_attention_mask"] = tf.constant(
- np.concatenate([np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1)
- )
+ tf_inputs_dict = config_inputs_dict
+ # `encoder_hidden_states` is not used in model call/forward
+ del tf_inputs_dict["encoder_hidden_states"]
+
+ # Make sure no sequence has an all-zero attention mask; otherwise some tests fail due to the inconsistent
+ # use of large negative values (`1e-4`, `1e-9`, `1e-30`, `-inf`) across implementations.
+ for k in ["attention_mask", "decoder_attention_mask"]:
+ attention_mask = tf_inputs_dict[k]
+
+ # Avoid all-zero attention masks for now by putting a `1` at the beginning of each sequence,
+ # so the mask still works after being combined with the causal attention mask.
+ # TODO: remove this once a fix regarding large negative values for attention masks is done.
+ attention_mask = tf.concat(
+ [tf.ones_like(attention_mask[:, :1], dtype=attention_mask.dtype), attention_mask[:, 1:]], axis=-1
+ )
+ tf_inputs_dict[k] = attention_mask
- # TF models don't use the `use_cache` option and cache is not returned as a default.
- # So we disable `use_cache` here for PyTorch model.
- decoder_config.use_cache = False
+ tf_inputs_dict_with_labels = copy.copy(tf_inputs_dict)
+ tf_inputs_dict_with_labels["labels"] = labels
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
- # check without `enc_to_dec_proj` projection
+ # Original test: check without `labels` and without `enc_to_dec_proj` projection
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
- self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
- self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
+ self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict)
+ self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict)
- # check equivalence with labels
- self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels)
- self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels)
+ # check with `labels`
+ self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict_with_labels)
+ self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict_with_labels)
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
# which randomly initialize `enc_to_dec_proj`.
- # # check `enc_to_dec_proj` work as expected
+ # check `enc_to_dec_proj` works as expected
# decoder_config.hidden_size = decoder_config.hidden_size * 2
# self.assertTrue(config.hidden_size != decoder_config.hidden_size)
- # self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
- # self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
+ # self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict)
+ # self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict)
# Let's just check `enc_to_dec_proj` can run for now
decoder_config.hidden_size = decoder_config.hidden_size * 2
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
model = TFEncoderDecoderModel(encoder_decoder_config)
- model(**inputs_dict)
+ model(tf_inputs_dict)
def test_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
@@ -554,6 +742,10 @@ def test_model_save_load_from_pretrained(self):
@require_tf
class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
+ def setUp(self):
+ self.encoder_model_tester = TFBertModelTester(self, batch_size=13)
+ self.decoder_model_tester = TFBertModelTester(self, batch_size=13)
+
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-bert",
@@ -566,10 +758,8 @@ def get_encoder_decoder_model(self, config, decoder_config):
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
- model_tester_encoder = TFBertModelTester(self, batch_size=13)
- model_tester_decoder = TFBertModelTester(self, batch_size=13)
- encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
- decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
+ encoder_config_and_inputs = self.encoder_model_tester.prepare_config_and_inputs()
+ decoder_config_and_inputs = self.decoder_model_tester.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
@@ -652,6 +842,10 @@ def test_bert2bert_summarization(self):
@require_tf
class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
+ def setUp(self):
+ self.encoder_model_tester = TFBertModelTester(self, batch_size=13)
+ self.decoder_model_tester = TFGPT2ModelTester(self)
+
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-bert",
@@ -664,10 +858,8 @@ def get_encoder_decoder_model(self, config, decoder_config):
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
- model_tester_encoder = TFBertModelTester(self, batch_size=13)
- model_tester_decoder = TFGPT2ModelTester(self)
- encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
- decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
+ encoder_config_and_inputs = self.encoder_model_tester.prepare_config_and_inputs()
+ decoder_config_and_inputs = self.decoder_model_tester.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
@@ -744,6 +936,10 @@ def test_bert2gpt2_summarization(self):
@require_tf
class TFRoBertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
+ def setUp(self):
+ self.encoder_model_tester = TFRobertaModelTester(self)
+ self.decoder_model_tester = TFRobertaModelTester(self)
+
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-roberta",
@@ -756,10 +952,8 @@ def get_encoder_decoder_model(self, config, decoder_config):
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
- model_tester_encoder = TFRobertaModelTester(self)
- model_tester_decoder = TFRobertaModelTester(self)
- encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
- decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
+ encoder_config_and_inputs = self.encoder_model_tester.prepare_config_and_inputs()
+ decoder_config_and_inputs = self.decoder_model_tester.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
@@ -803,6 +997,10 @@ def prepare_config_and_inputs(self):
@require_tf
class TFRembertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
+ def setUp(self):
+ self.encoder_model_tester = TFRemBertModelTester(self)
+ self.decoder_model_tester = TFRemBertModelTester(self)
+
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-rembert",
@@ -815,10 +1013,8 @@ def get_encoder_decoder_model(self, config, decoder_config):
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
- model_tester_encoder = TFRemBertModelTester(self)
- model_tester_decoder = TFRemBertModelTester(self)
- encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
- decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
+ encoder_config_and_inputs = self.encoder_model_tester.prepare_config_and_inputs()
+ decoder_config_and_inputs = self.decoder_model_tester.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
diff --git a/tests/models/flava/__init__.py b/tests/models/flava/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/flava/test_feature_extraction_flava.py b/tests/models/flava/test_feature_extraction_flava.py
new file mode 100644
index 000000000000..793aa913aeb0
--- /dev/null
+++ b/tests/models/flava/test_feature_extraction_flava.py
@@ -0,0 +1,347 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms authors and HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import FlavaFeatureExtractor
+ from transformers.models.flava.feature_extraction_flava import (
+ FLAVA_CODEBOOK_MEAN,
+ FLAVA_CODEBOOK_STD,
+ FLAVA_IMAGE_MEAN,
+ FLAVA_IMAGE_STD,
+ )
+else:
+ FLAVA_IMAGE_MEAN = FLAVA_IMAGE_STD = FLAVA_CODEBOOK_MEAN = FLAVA_CODEBOOK_STD = None
+
+
+class FlavaFeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=224,
+ do_center_crop=True,
+ crop_size=224,
+ resample=None,
+ do_normalize=True,
+ image_mean=FLAVA_IMAGE_MEAN,
+ image_std=FLAVA_IMAGE_STD,
+ input_size_patches=14,
+ total_mask_patches=75,
+ mask_group_max_patches=None,
+ mask_group_min_patches=16,
+ mask_group_min_aspect_ratio=0.3,
+ mask_group_max_aspect_ratio=None,
+ codebook_do_resize=True,
+ codebook_size=112,
+ codebook_resample=None,
+ codebook_do_center_crop=True,
+ codebook_crop_size=112,
+ codebook_do_map_pixels=True,
+ codebook_do_normalize=True,
+ codebook_image_mean=FLAVA_CODEBOOK_MEAN,
+ codebook_image_std=FLAVA_CODEBOOK_STD,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.do_resize = do_resize
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.size = size
+ self.resample = resample if resample is not None else Image.BICUBIC
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+
+ self.input_size_patches = input_size_patches
+ self.total_mask_patches = total_mask_patches
+ self.mask_group_max_patches = mask_group_max_patches
+ self.mask_group_min_patches = mask_group_min_patches
+ self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio
+ self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio
+
+ self.codebook_do_resize = codebook_do_resize
+ self.codebook_size = codebook_size
+ self.codebook_resample = codebook_resample if codebook_resample is not None else Image.LANCZOS
+ self.codebook_do_center_crop = codebook_do_center_crop
+ self.codebook_crop_size = codebook_crop_size
+ self.codebook_do_map_pixels = codebook_do_map_pixels
+ self.codebook_do_normalize = codebook_do_normalize
+ self.codebook_image_mean = codebook_image_mean
+ self.codebook_image_std = codebook_image_std
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_normalize": self.do_normalize,
+ "do_resize": self.do_resize,
+ "size": self.size,
+ "resample": self.resample,
+ "do_center_crop": self.do_center_crop,
+ "crop_size": self.crop_size,
+ "input_size_patches": self.input_size_patches,
+ "total_mask_patches": self.total_mask_patches,
+ "mask_group_max_patches": self.mask_group_max_patches,
+ "mask_group_min_patches": self.mask_group_min_patches,
+ "mask_group_min_aspect_ratio": self.mask_group_min_aspect_ratio,
+ "mask_group_max_aspect_ratio": self.mask_group_min_aspect_ratio,
+ "codebook_do_resize": self.codebook_do_resize,
+ "codebook_size": self.codebook_size,
+ "codebook_resample": self.codebook_resample,
+ "codebook_do_center_crop": self.codebook_do_center_crop,
+ "codebook_crop_size": self.codebook_crop_size,
+ "codebook_do_map_pixels": self.codebook_do_map_pixels,
+ "codebook_do_normalize": self.codebook_do_normalize,
+ "codebook_image_mean": self.codebook_image_mean,
+ "codebook_image_std": self.codebook_image_std,
+ }
+
+ def get_expected_image_size(self):
+ return (self.size, self.size) if not isinstance(self.size, tuple) else self.size
+
+ def get_expected_mask_size(self):
+ return (
+ (self.input_size_patches, self.input_size_patches)
+ if not isinstance(self.input_size_patches, tuple)
+ else self.input_size_patches
+ )
+
+ def get_expected_codebook_image_size(self):
+ if not isinstance(self.codebook_size, tuple):
+ return (self.codebook_size, self.codebook_size)
+ else:
+ return self.codebook_size
+
+
+@require_torch
+@require_vision
+class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = FlavaFeatureExtractor if is_vision_available() else None
+ maxDiff = None
+
+ def setUp(self):
+ self.feature_extract_tester = FlavaFeatureExtractionTester(self)
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "image_std"))
+ self.assertTrue(hasattr(feature_extractor, "do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "resample"))
+ self.assertTrue(hasattr(feature_extractor, "crop_size"))
+ self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "masking_generator"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_size"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_resample"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_do_center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_crop_size"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_do_map_pixels"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_image_std"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt")
+
+ # Test no bool masked pos
+ self.assertFalse("bool_masked_pos" in encoded_images)
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt")
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+
+ # Test no bool masked pos
+ self.assertFalse("bool_masked_pos" in encoded_images)
+
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ def _test_call_framework(self, instance_class, prepare_kwargs):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, **prepare_kwargs)
+ for image in image_inputs:
+ self.assertIsInstance(image, instance_class)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt")
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
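+ # Batched input with `return_image_mask=True` should return `bool_masked_pos` alongside `pixel_values`.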
+ encoded_images = feature_extractor(image_inputs, return_image_mask=True, return_tensors="pt")
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_mask_size()
+ self.assertEqual(
+ encoded_images.bool_masked_pos.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ # Test masking
+ encoded_images = feature_extractor(image_inputs, return_image_mask=True, return_tensors="pt")
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_mask_size()
+ self.assertEqual(
+ encoded_images.bool_masked_pos.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ def test_call_numpy(self):
+ self._test_call_framework(np.ndarray, prepare_kwargs={"numpify": True})
+
+ def test_call_pytorch(self):
+ self._test_call_framework(torch.Tensor, prepare_kwargs={"torchify": True})
+
+ def test_masking(self):
+ # Initialize feature_extractor
+ random.seed(1234)
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_image_mask=True, return_tensors="pt")
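+ # `total_mask_patches` is 75 in this tester, so exactly 75 patch positions should be masked.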
+ self.assertEqual(encoded_images.bool_masked_pos.sum().item(), 75)
+
+ def test_codebook_pixels(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_codebook_pixels=True, return_tensors="pt")
+ expected_height, expected_width = self.feature_extract_tester.get_expected_codebook_image_size()
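+ # Codebook inputs are resized to `codebook_size` and center-cropped to `codebook_crop_size` (both 112 in this tester).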
+ self.assertEqual(
+ encoded_images.codebook_pixel_values.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_codebook_pixels=True, return_tensors="pt")
+ expected_height, expected_width = self.feature_extract_tester.get_expected_codebook_image_size()
+ self.assertEqual(
+ encoded_images.codebook_pixel_values.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
diff --git a/tests/models/flava/test_modeling_flava.py b/tests/models/flava/test_modeling_flava.py
new file mode 100644
index 000000000000..62b89e3977c3
--- /dev/null
+++ b/tests/models/flava/test_modeling_flava.py
@@ -0,0 +1,1228 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch FLAVA model. """
+
+
+import inspect
+import os
+import random
+import tempfile
+import unittest
+
+import numpy as np
+
+import requests
+from transformers import (
+ FlavaConfig,
+ FlavaImageCodebookConfig,
+ FlavaImageConfig,
+ FlavaMultimodalConfig,
+ FlavaTextConfig,
+)
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ _config_zero_init,
+ floats_tensor,
+ ids_tensor,
+ random_attention_mask,
+)
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import (
+ FlavaForPreTraining,
+ FlavaImageCodebook,
+ FlavaImageModel,
+ FlavaModel,
+ FlavaMultimodalModel,
+ FlavaTextModel,
+ )
+ from transformers.models.flava.modeling_flava import (
+ FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST,
+ FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
+ )
+else:
+ FlavaModel = None
+ FlavaForPreTraining = None
+ torch = {}
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import FlavaProcessor
+
+
+class FlavaImageModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=12,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ image_size=30,
+ patch_size=2,
+ num_channels=3,
+ qkv_bias=True,
+ mask_token=True,
+ vocab_size=8192,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.mask_token = mask_token
+ self.vocab_size = vocab_size
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+ num_patches = self.image_size // self.patch_size
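+ # Mask roughly 90% of the patch positions; the mask has shape (batch_size, patches_per_side, patches_per_side).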
+ bool_masked_pos = (
+ torch.rand((self.batch_size, num_patches, num_patches), device=pixel_values.device) < 0.9
+ ).long()
+ config = self.get_config()
+ return config, pixel_values, bool_masked_pos
+
+ def get_config(self):
+ return FlavaImageConfig(
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ initializer_range=self.initializer_range,
+ layer_norm_eps=self.layer_norm_eps,
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ qkv_bias=self.qkv_bias,
+ mask_token=self.mask_token,
+ vocab_size=self.vocab_size,
+ )
+
+ def create_and_check_model(self, config, pixel_values, bool_masked_pos):
+ model = FlavaImageModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(pixel_values, bool_masked_pos)
+ # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
+ image_size = (self.image_size, self.image_size)
+ patch_size = (self.patch_size, self.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, bool_masked_pos = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values, "bool_masked_pos": bool_masked_pos}
+ return config, inputs_dict
+
+
+@require_torch
+class FlavaImageModelTest(ModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as FLAVA does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (FlavaImageModel,) if is_torch_available() else ()
+
+ test_pruning = False
+ test_torchscript = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = FlavaImageModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=FlavaImageConfig, has_text_modality=False, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_inputs_embeds(self):
+ # FLAVA does not use inputs_embeds
+ pass
+
+ def test_model_common_attributes(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ # in FLAVA, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+ image_size = (self.model_tester.image_size, self.model_tester.image_size)
+ patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ seq_len = num_patches + 1
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
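+ # Requesting hidden states should add exactly one extra entry (`hidden_states`) to the outputs.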
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, seq_len, seq_len],
+ )
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+
+ expected_num_layers = getattr(
+ self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
+ )
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ # FLAVA has a different seq_length
+ image_size = (self.model_tester.image_size, self.model_tester.image_size)
+ patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ seq_length = num_patches + 1
+
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [seq_length, self.model_tester.hidden_size],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ def test_training(self):
+ pass
+
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ # skip this test as FlavaImageModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ # skip this test as FlavaImageModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FlavaImageModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class FlavaTextModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=12,
+ seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_token_type_ids=True,
+ vocab_size=30522,
+ type_vocab_size=2,
+ max_position_embeddings=512,
+ position_embedding_type="absolute",
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=0,
+ qkv_bias=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_token_type_ids = use_token_type_ids
+ self.seq_length = seq_length
+ self.vocab_size = vocab_size
+ self.type_vocab_size = type_vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.position_embedding_type = position_embedding_type
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.qkv_bias = qkv_bias
+ self.pad_token_id = pad_token_id
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
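+ # Rewrite each mask as a contiguous prefix of ones followed by zeros, so every sequence attends to at least one token.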
+ if input_mask is not None:
+ batch_size, seq_length = input_mask.shape
+ rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
+ for batch_idx, start_index in enumerate(rnd_start_indices):
+ input_mask[batch_idx, :start_index] = 1
+ input_mask[batch_idx, start_index:] = 0
+
+ token_type_ids = None
+
+ if self.use_token_type_ids:
+ token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+ config = self.get_config()
+
+ return config, input_ids, token_type_ids, input_mask
+
+ def get_config(self):
+ return FlavaTextConfig(
+ vocab_size=self.vocab_size,
+ type_vocab_size=self.type_vocab_size,
+ max_position_embeddings=self.max_position_embeddings,
+ position_embedding_type=self.position_embedding_type,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ initializer_range=self.initializer_range,
+ layer_norm_eps=self.layer_norm_eps,
+ pad_token_id=self.pad_token_id,
+ qkv_bias=self.qkv_bias,
+ )
+
+ def create_and_check_model(self, config, input_ids, token_type_ids, input_mask):
+ model = FlavaTextModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(input_ids, token_type_ids=token_type_ids, attention_mask=input_mask)
+ result = model(input_ids)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, input_ids, token_type_ids, input_mask = config_and_inputs
+ inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class FlavaTextModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (FlavaTextModel,) if is_torch_available() else ()
+ test_pruning = False
+ test_head_masking = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = FlavaTextModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=FlavaTextConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_training(self):
+ pass
+
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ def test_inputs_embeds(self):
+ # FLAVA does not use inputs_embeds
+ pass
+
+ # skip this test as FlavaTextModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ # skip this test as FlavaTextModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FlavaTextModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class FlavaMultimodalModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=12,
+ seq_length=44,
+ use_input_mask=True,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ qkv_bias=True,
+ ce_ignore_index=-100,
+ use_cls_token=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.use_input_mask = use_input_mask
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.qkv_bias = qkv_bias
+ self.ce_ignore_index = ce_ignore_index
+ self.use_cls_token = use_cls_token
+
+ def prepare_config_and_inputs(self):
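+ # Use seq_length - 1 positions here: with `use_cls_token=True` the multimodal encoder prepends a CLS token,
+ # bringing the output sequence length back to seq_length.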
+ hidden_states = floats_tensor([self.batch_size, self.seq_length - 1, self.hidden_size])
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ if input_mask is not None:
+ batch_size, seq_length = input_mask.shape
+ rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
+ for batch_idx, start_index in enumerate(rnd_start_indices):
+ input_mask[batch_idx, :start_index] = 1
+ input_mask[batch_idx, start_index:] = 0
+
+ config = self.get_config()
+
+ return config, hidden_states, input_mask
+
+ def get_config(self):
+ return FlavaMultimodalConfig(
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ initializer_range=self.initializer_range,
+ layer_norm_eps=self.layer_norm_eps,
+ qkv_bias=self.qkv_bias,
+ use_cls_token=self.use_cls_token,
+ ce_ignore_index=self.ce_ignore_index,
+ )
+
+ def create_and_check_model(self, config, hidden_states, input_mask):
+ model = FlavaMultimodalModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(hidden_states, attention_mask=input_mask)
+ result = model(hidden_states)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, hidden_states, input_mask = config_and_inputs
+ inputs_dict = {"hidden_states": hidden_states, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class FlavaMultimodalModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (FlavaMultimodalModel,) if is_torch_available() else ()
+ test_pruning = False
+ test_head_masking = False
+ test_resize_embeddings = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = FlavaMultimodalModelTester(self)
+ self.config_tester = ConfigTester(
+ self, config_class=FlavaMultimodalConfig, has_text_modality=False, hidden_size=37
+ )
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["hidden_states"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model_common_attributes(self):
+ # No embedding in multimodal model
+ pass
+
+ def test_training(self):
+ pass
+
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ def test_inputs_embeds(self):
+ # FLAVA does not use inputs_embeds
+ pass
+
+ # skip this test as FlavaMultimodalModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ # skip this test as FlavaMultimodalModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FlavaMultimodalModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class FlavaImageCodebookTester:
+ def __init__(self, parent, batch_size=12, image_size=112, num_channels=3):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.num_channels = num_channels
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+ config = self.get_config()
+
+ return config, pixel_values
+
+ def get_config(self):
+ return FlavaImageCodebookConfig()
+
+ def create_and_check_model(self, config, pixel_values):
+ model = FlavaImageCodebook(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(pixel_values)
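+ # The codebook downsamples the spatial resolution by a factor of 8 and returns `vocab_size` logits per position.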
+ self.parent.assertEqual(
+ result.shape, (self.batch_size, config.vocab_size, self.image_size // 8, self.image_size // 8)
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class FlavaImageCodebookTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (FlavaImageCodebook,) if is_torch_available() else ()
+ test_pruning = False
+ test_head_masking = False
+ test_resize_embeddings = False
+ test_torchscript = False
+ has_attentions = False
+
+ def setUp(self):
+ self.model_tester = FlavaImageCodebookTester(self)
+ self.config_tester = ConfigTester(self, config_class=FlavaImageCodebookConfig, has_text_modality=False)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ @unittest.skip(reason="Flava does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
+ def test_model_common_attributes(self):
+ # The image codebook has no input/output embeddings
+ pass
+
+ def test_training(self):
+ pass
+
+ def test_hidden_states_output(self):
+ pass
+
+ def test_retain_grad_hidden_states_attentions(self):
+ # no attentions
+ pass
+
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ def test_inputs_embeds(self):
+ # FLAVA does not use inputs_embeds
+ pass
+
+ def test_model_outputs_equivalence(self):
+ pass
+
+ # skip this test as FlavaImageCodebook has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ # skip this test as FlavaImageCodebook has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FlavaImageCodebook.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class FlavaModelTester:
+ model_class = FlavaModel
+
+ def __init__(
+ self,
+ parent,
+ is_training=True,
+ hidden_size=32,
+ projection_dim=32,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ ):
+ self.parent = parent
+ self.image_model_tester = FlavaImageModelTester(parent)
+ self.text_model_tester = FlavaTextModelTester(parent)
+ self.multimodal_model_tester = FlavaMultimodalModelTester(parent)
+ self.image_codebook_tester = FlavaImageCodebookTester(parent)
+ self.is_training = is_training
+ self.config_tester = ConfigTester(self, config_class=FlavaConfig, hidden_size=37)
+ self.hidden_size = hidden_size
+ self.projection_dim = projection_dim
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def prepare_config_and_inputs_for_common(self):
+ _, pixel_values, bool_masked_pos = self.image_model_tester.prepare_config_and_inputs()
+ _, input_ids, token_type_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
+
+ config = self.get_config()
+
+ return config, {
+ "input_ids": input_ids,
+ "token_type_ids": token_type_ids,
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ "bool_masked_pos": bool_masked_pos,
+ }
+
+ def get_config(self):
+ return FlavaConfig.from_configs(
+ self.image_model_tester.get_config(),
+ self.text_model_tester.get_config(),
+ self.multimodal_model_tester.get_config(),
+ self.image_codebook_tester.get_config(),
+ hidden_size=self.hidden_size,
+ projection_dim=self.projection_dim,
+ initializer_range=self.initializer_range,
+ layer_norm_eps=self.layer_norm_eps,
+ )
+
+ def create_and_check_model(self, config, inputs):
+ self._test_model(config, inputs, test_image=True)
+ self._test_model(config, inputs, test_text=True)
+ self._test_model(config, inputs, test_image=True, test_text=True)
+
+ def _test_model(self, config, inputs, test_image=False, test_text=False):
+ model = self.model_class(config).to(torch_device).eval()
+ with torch.no_grad():
+ result = model(
+ input_ids=inputs["input_ids"] if test_text else None,
+ attention_mask=inputs["attention_mask"] if test_text else None,
+ token_type_ids=inputs["token_type_ids"] if test_text else None,
+ pixel_values=inputs["pixel_values"] if test_image else None,
+ bool_masked_pos=inputs["bool_masked_pos"] if test_image else None,
+ )
+ image_size = (self.image_model_tester.image_size, self.image_model_tester.image_size)
+ patch_size = (self.image_model_tester.patch_size, self.image_model_tester.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+
+ if test_image:
+ self.parent.assertEqual(
+ result.image_embeddings.shape,
+ (self.image_model_tester.batch_size, num_patches + 1, self.image_model_tester.hidden_size),
+ )
+ else:
+ self.parent.assertIsNone(result.image_embeddings)
+
+ if test_text:
+ self.parent.assertEqual(
+ result.text_embeddings.shape,
+ (
+ self.text_model_tester.batch_size,
+ self.text_model_tester.seq_length,
+ self.text_model_tester.hidden_size,
+ ),
+ )
+ else:
+ self.parent.assertIsNone(result.text_embeddings)
+
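+ # With both modalities, the multimodal sequence concatenates the image sequence (num_patches + 1, including
+ # its CLS token) with the text sequence, plus one extra CLS token added by the multimodal encoder.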
+ if test_image and test_text:
+ self.parent.assertEqual(
+ result.multimodal_embeddings.shape,
+ (
+ self.multimodal_model_tester.batch_size,
+ self.text_model_tester.seq_length + num_patches + 2,
+ self.multimodal_model_tester.hidden_size,
+ ),
+ )
+ else:
+ self.parent.assertIsNone(result.multimodal_embeddings)
+
+
+@require_torch
+class FlavaModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (FlavaModel,) if is_torch_available() else ()
+ class_for_tester = FlavaModelTester
+ test_head_masking = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_attention_outputs = False
+
+ def setUp(self):
+ self.model_tester = self.class_for_tester(self)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ # hidden_states are tested in individual model tests
+ def test_hidden_states_output(self):
+ pass
+
+ # input_embeds are tested in individual model tests
+ def test_inputs_embeds(self):
+ pass
+
+ # tested in individual model tests
+ def test_retain_grad_hidden_states_attentions(self):
+ pass
+
+ # FlavaModel does not have input/output embeddings
+ def test_model_common_attributes(self):
+ pass
+
+ # override as the `logit_scale` parameter initialization is different for FLAVA
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ # check if `logit_scale` is initialized as per the original implementation
+ if name == "logit_scale" or name == "flava.logit_scale":
+ self.assertAlmostEqual(
+ param.data.item(),
+ np.log(1 / 0.07),
+ delta=1e-3,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ def _create_and_check_torchscript(self, config, inputs_dict):
+ if not self.test_torchscript:
+ return
+
+ configs_no_init = _config_zero_init(config)  # To be sure we have no NaN
+ configs_no_init.torchscript = True
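+ # torch.jit.trace only handles tensor/tuple outputs, so dict-style returns (and the optional loss) are disabled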
+ configs_no_init.return_dict = False
+ configs_no_init.return_loss = False
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+
+ try:
+ input_ids = inputs_dict["input_ids"]
+ pixel_values = inputs_dict["pixel_values"] # FLAVA needs pixel_values
+
+ if "input_ids_masked" in inputs_dict:
+ # For pretraining
+ inputs = (input_ids, inputs_dict["input_ids_masked"], pixel_values)
+ else:
+ inputs = (input_ids, pixel_values)
+
+ traced_model = torch.jit.trace(model, inputs)
+ except RuntimeError:
+ self.fail("Couldn't trace module.")
+
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
+
+ try:
+ torch.jit.save(traced_model, pt_file_name)
+ except Exception:
+ self.fail("Couldn't save module.")
+
+ try:
+ loaded_model = torch.jit.load(pt_file_name)
+ except Exception:
+ self.fail("Couldn't load module.")
+
+ model.to(torch_device)
+ model.eval()
+
+ loaded_model.to(torch_device)
+ loaded_model.eval()
+
+ model_state_dict = model.state_dict()
+ loaded_model_state_dict = loaded_model.state_dict()
+ # Non persistent buffers won't be in original state dict
+ loaded_model_state_dict.pop("text_model.embeddings.token_type_ids", None)
+
+ self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
+
+ models_equal = True
+ for layer_name, p1 in model_state_dict.items():
+ p2 = loaded_model_state_dict[layer_name]
+ if p1.data.ne(p2.data).sum() > 0:
+ models_equal = False
+
+ self.assertTrue(models_equal)
+
+ def test_load_image_text_config(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ # Save FlavaConfig and check if we can load FlavaImageConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ image_config = FlavaImageConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.image_config.to_dict(), image_config.to_dict())
+
+ # Save FlavaConfig and check if we can load FlavaTextConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ text_config = FlavaTextConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
+
+ # Save FlavaConfig and check if we can load FlavaMultimodalConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ multimodal_config = FlavaMultimodalConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.multimodal_config.to_dict(), multimodal_config.to_dict())
+
+ # overwrite from common since FlavaModel/TFFlavaModel return FLAVAOutput/TFFLAVAOutput
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FlavaModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class FlavaForPreTrainingTester(FlavaModelTester):
+ model_class = FlavaForPreTraining
+
+ def prepare_config_and_inputs_for_common(self):
+ _, pixel_values, bool_masked_pos = self.image_model_tester.prepare_config_and_inputs()
+ _, input_ids, token_type_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
+ config = self.get_config()
+
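+ # simulate MLM inputs: overwrite two token positions with a dummy mask id and supervise only those positions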
+ input_ids_masked = input_ids.detach().clone()
+ input_ids_masked[:, 1:3] = 100
+ mlm_labels = input_ids.detach().clone()
+ mlm_labels[:, :] = config.ce_ignore_index
+ mlm_labels[:, 1:3] = input_ids[:, 1:3]
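+ # MIM targets are random codebook ids, supervised only on masked patches; ITM labels of 1 treat every pair as matched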
+ mim_labels = torch.randint(
+ 0, self.image_model_tester.vocab_size, bool_masked_pos.size(), device=bool_masked_pos.device
+ ).long()
+ mim_labels[bool_masked_pos.ne(True)] = config.ce_ignore_index
+ itm_labels = torch.ones(mlm_labels.size(0), device=bool_masked_pos.device).long()
+
+ return config, {
+ "input_ids": input_ids,
+ "input_ids_masked": input_ids_masked,
+ "token_type_ids": token_type_ids,
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ "bool_masked_pos": bool_masked_pos,
+ "mlm_labels": mlm_labels,
+ "mim_labels": mim_labels,
+ "itm_labels": itm_labels,
+ "return_loss": True,
+ }
+
+ def _test_model(self, config, inputs, test_image=False, test_text=False):
+ model = self.model_class(config).to(torch_device).eval()
+ with torch.no_grad():
+ result = model(
+ input_ids=inputs["input_ids"] if test_text else None,
+ input_ids_masked=inputs["input_ids_masked"] if test_text else None,
+ attention_mask=inputs["attention_mask"] if test_text else None,
+ token_type_ids=inputs["token_type_ids"] if test_text else None,
+ pixel_values=inputs["pixel_values"] if test_image else None,
+ bool_masked_pos=inputs["bool_masked_pos"] if test_image else None,
+ mlm_labels=inputs["mlm_labels"],
+ mim_labels=inputs["mim_labels"],
+ itm_labels=inputs["itm_labels"],
+ return_loss=inputs["return_loss"],
+ )
+ image_size = (self.image_model_tester.image_size, self.image_model_tester.image_size)
+ patch_size = (self.image_model_tester.patch_size, self.image_model_tester.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+
+ if test_image:
+ self.parent.assertEqual(
+ result.image_embeddings.shape,
+ (self.image_model_tester.batch_size, num_patches + 1, self.image_model_tester.hidden_size),
+ )
+ if not test_text:
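+ # with image-only inputs, only the MIM objective is active: a scalar loss and logits restricted to the masked patches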
+ self.parent.assertEqual(
+ result.loss_info.mim.dim(),
+ 0,
+ )
+ self.parent.assertEqual(
+ result.mim_logits.shape,
+ (inputs["bool_masked_pos"].sum().item(), self.image_model_tester.vocab_size),
+ )
+
+ else:
+ self.parent.assertIsNone(result.image_embeddings)
+
+ if test_text:
+ self.parent.assertEqual(
+ result.text_embeddings.shape,
+ (
+ self.text_model_tester.batch_size,
+ self.text_model_tester.seq_length,
+ self.text_model_tester.hidden_size,
+ ),
+ )
+ if not test_image:
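+ # with text-only inputs, only the MLM objective is active: a scalar loss and logits restricted to the supervised positions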
+ self.parent.assertEqual(result.loss_info.mlm.dim(), 0)
+ self.parent.assertEqual(
+ result.mlm_logits.shape,
+ (
+ (inputs["mlm_labels"] != self.multimodal_model_tester.ce_ignore_index).sum().item(),
+ self.text_model_tester.vocab_size,
+ ),
+ )
+ else:
+ self.parent.assertIsNone(result.text_embeddings)
+
+ if test_image and test_text:
+ self.parent.assertEqual(
+ result.multimodal_masked_embeddings.shape,
+ (
+ self.multimodal_model_tester.batch_size,
+ self.text_model_tester.seq_length + num_patches + 2,
+ self.multimodal_model_tester.hidden_size,
+ ),
+ )
+ self.parent.assertEqual(
+ result.itm_logits.shape,
+ (self.text_model_tester.batch_size, 2),
+ )
+ self.parent.assertEqual(
+ result.mmm_text_logits.shape,
+ (
+ (inputs["mlm_labels"] != self.multimodal_model_tester.ce_ignore_index).sum().item(),
+ self.text_model_tester.vocab_size,
+ ),
+ )
+ self.parent.assertEqual(
+ result.mmm_image_logits.shape,
+ (inputs["bool_masked_pos"].sum().item(), self.image_model_tester.vocab_size),
+ )
+ self.parent.assertEqual(
+ result.contrastive_logits_per_image.shape,
+ (self.image_model_tester.batch_size, self.text_model_tester.batch_size),
+ )
+ self.parent.assertEqual(
+ result.contrastive_logits_per_text.shape,
+ (self.text_model_tester.batch_size, self.image_model_tester.batch_size),
+ )
+
+ for item in [
+ result.loss_info.global_contrastive,
+ result.loss_info.itm,
+ result.loss_info.mmm_text,
+ result.loss_info.mmm_image,
+ ]:
+ self.parent.assertEqual(item.dim(), 0)
+
+ for item in [result.loss_info.mim, result.loss_info.mlm]:
+ self.parent.assertIsNone(item)
+
+ else:
+ self.parent.assertIsNone(result.multimodal_masked_embeddings)
+ for item in [
+ result.loss_info.global_contrastive,
+ result.loss_info.itm,
+ result.loss_info.mmm_text,
+ result.loss_info.mmm_image,
+ ]:
+ self.parent.assertIsNone(item)
+
+ self.parent.assertIsNone(result.multimodal_embeddings)
+
+
+@require_torch
+class FlavaForPreTrainingTest(FlavaModelTest):
+ all_model_classes = (FlavaForPreTraining,) if is_torch_available() else ()
+ class_for_tester = FlavaForPreTrainingTester
+ test_torchscript = False
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ im = Image.open(requests.get(url, stream=True).raw)
+ return im
+
+
+@require_vision
+@require_torch
+class FlavaModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference(self):
+ model_name = "facebook/flava-full"
+ model = FlavaModel.from_pretrained(model_name).to(torch_device)
+ processor = FlavaProcessor.from_pretrained(model_name)
+
+ image = prepare_img()
+ inputs = processor(
+ text=["a photo of a cat", "a photo of a dog"],
+ images=[image, image],
+ padding="max_length",
+ max_length=77,
+ return_tensors="pt",
+ ).to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs, return_dict=True)
+
+ # verify the embeddings
+ self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.53540, places=4)
+ self.assertAlmostEqual(outputs.text_embeddings.sum().item(), -198.98225, places=4)
+ self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -3988.51367, places=4)
+
+
+@require_vision
+@require_torch
+class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference(self):
+ model_name = "facebook/flava-full"
+ model = FlavaForPreTraining.from_pretrained(model_name).to(torch_device)
+ processor = FlavaProcessor.from_pretrained(model_name)
+ torch.manual_seed(1)
+ random.seed(1)
+
+ image = prepare_img()
+ inputs = processor(
+ text=["a photo of a cat", "a photo of a dog"],
+ images=[image, image],
+ padding="max_length",
+ max_length=77,
+ return_tensors="pt",
+ return_codebook_pixels=True,
+ return_image_mask=True,
+ )
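+ # mask two text positions by hand (103 is the [MASK] id in the underlying BERT vocabulary) and ignore all other positions in the loss (-100)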
+ inputs["input_ids_masked"] = inputs["input_ids"].clone()
+ inputs["input_ids_masked"][0, 4:6] = 103
+ inputs["mlm_labels"] = inputs["input_ids"].clone()
+ inputs["mlm_labels"][:, :] = -100
+ inputs["mlm_labels"][0, 4:6] = inputs["input_ids"][0, 4:6]
+ inputs = inputs.to(torch_device)
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ self.assertEqual(
+ outputs.contrastive_logits_per_image.shape,
+ torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
+ )
+ self.assertEqual(
+ outputs.contrastive_logits_per_text.shape,
+ torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
+ )
+
+ expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
+ self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
+ self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
+ self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4)
+ self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4)
diff --git a/tests/models/flava/test_processor_flava.py b/tests/models/flava/test_processor_flava.py
new file mode 100644
index 000000000000..21cc84d5f299
--- /dev/null
+++ b/tests/models/flava/test_processor_flava.py
@@ -0,0 +1,234 @@
+# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import random
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+import pytest
+
+from transformers import BertTokenizer, BertTokenizerFast
+from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES
+from transformers.testing_utils import require_vision
+from transformers.utils import FEATURE_EXTRACTOR_NAME, is_vision_available
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import FlavaFeatureExtractor, FlavaProcessor
+ from transformers.models.flava.feature_extraction_flava import (
+ FLAVA_CODEBOOK_MEAN,
+ FLAVA_CODEBOOK_STD,
+ FLAVA_IMAGE_MEAN,
+ FLAVA_IMAGE_STD,
+ )
+
+
+@require_vision
+class FlavaProcessorTest(unittest.TestCase):
+ def setUp(self):
+ self.tmpdirname = tempfile.mkdtemp()
+
+ # fmt: off
+ vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ",", "low", "lowest"]
+ # fmt: on
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+
+ with open(self.vocab_file, "w", encoding="utf-8") as fp:
+ fp.write("".join([x + "\n" for x in vocab_tokens]))
+
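+ # minimal FLAVA feature-extractor config, including the image-masking and codebook preprocessing options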
+ feature_extractor_map = {
+ "image_mean": FLAVA_IMAGE_MEAN,
+ "image_std": FLAVA_IMAGE_STD,
+ "do_normalize": True,
+ "do_resize": True,
+ "size": 224,
+ "do_center_crop": True,
+ "crop_size": 224,
+ "input_size_patches": 14,
+ "total_mask_patches": 75,
+ "mask_group_max_patches": None,
+ "mask_group_min_patches": 16,
+ "mask_group_min_aspect_ratio": 0.3,
+ "mask_group_max_aspect_ratio": None,
+ "codebook_do_resize": True,
+ "codebook_size": 112,
+ "codebook_resample": None,
+ "codebook_do_center_crop": True,
+ "codebook_crop_size": 112,
+ "codebook_do_map_pixels": True,
+ "codebook_do_normalize": True,
+ "codebook_image_mean": FLAVA_CODEBOOK_MEAN,
+ "codebook_image_std": FLAVA_CODEBOOK_STD,
+ }
+
+ self.feature_extractor_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
+ with open(self.feature_extractor_file, "w", encoding="utf-8") as fp:
+ json.dump(feature_extractor_map, fp)
+
+ def get_tokenizer(self, **kwargs):
+ return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_rust_tokenizer(self, **kwargs):
+ return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_feature_extractor(self, **kwargs):
+ return FlavaFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdirname)
+
+ def prepare_image_inputs(self):
+ """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
+ or a list of PyTorch tensors if one specifies torchify=True.
+ """
+
+ image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
+
+ image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
+
+ return image_inputs
+
+ def test_save_load_pretrained_default(self):
+ tokenizer_slow = self.get_tokenizer()
+ tokenizer_fast = self.get_rust_tokenizer()
+ feature_extractor = self.get_feature_extractor()
+
+ processor_slow = FlavaProcessor(tokenizer=tokenizer_slow, feature_extractor=feature_extractor)
+ processor_slow.save_pretrained(self.tmpdirname)
+ processor_slow = FlavaProcessor.from_pretrained(self.tmpdirname, use_fast=False)
+
+ processor_fast = FlavaProcessor(tokenizer=tokenizer_fast, feature_extractor=feature_extractor)
+ processor_fast.save_pretrained(self.tmpdirname)
+ processor_fast = FlavaProcessor.from_pretrained(self.tmpdirname)
+
+ self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab())
+ self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab())
+ self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab())
+ self.assertIsInstance(processor_slow.tokenizer, BertTokenizer)
+ self.assertIsInstance(processor_fast.tokenizer, BertTokenizerFast)
+
+ self.assertEqual(processor_slow.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertEqual(processor_fast.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertIsInstance(processor_slow.feature_extractor, FlavaFeatureExtractor)
+ self.assertIsInstance(processor_fast.feature_extractor, FlavaFeatureExtractor)
+
+ def test_save_load_pretrained_additional_features(self):
+ processor = FlavaProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
+ processor.save_pretrained(self.tmpdirname)
+
+ tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
+ feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
+
+ processor = FlavaProcessor.from_pretrained(
+ self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
+ )
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+ self.assertIsInstance(processor.tokenizer, BertTokenizerFast)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, FlavaFeatureExtractor)
+
+ def test_feature_extractor(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ image_input = self.prepare_image_inputs()
+
+ input_feat_extract = feature_extractor(image_input, return_tensors="np")
+ input_processor = processor(images=image_input, return_tensors="np")
+
+ for key in input_feat_extract.keys():
+ self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+ # With rest of the args
+ random.seed(1234)
+ input_feat_extract = feature_extractor(
+ image_input, return_image_mask=True, return_codebook_pixels=True, return_tensors="np"
+ )
+ random.seed(1234)
+ input_processor = processor(
+ images=image_input, return_image_mask=True, return_codebook_pixels=True, return_tensors="np"
+ )
+
+ for key in input_feat_extract.keys():
+ self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+ def test_tokenizer(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ input_str = "lower newer"
+
+ encoded_processor = processor(text=input_str)
+
+ encoded_tok = tokenizer(input_str)
+
+ for key in encoded_tok.keys():
+ self.assertListEqual(encoded_tok[key], encoded_processor[key])
+
+ def test_processor(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ input_str = "lower newer"
+ image_input = self.prepare_image_inputs()
+
+ inputs = processor(text=input_str, images=image_input)
+
+ self.assertListEqual(list(inputs.keys()), ["input_ids", "token_type_ids", "attention_mask", "pixel_values"])
+
+ # add extra args
+ inputs = processor(text=input_str, images=image_input, return_codebook_pixels=True, return_image_mask=True)
+
+ self.assertListEqual(
+ list(inputs.keys()),
+ [
+ "input_ids",
+ "token_type_ids",
+ "attention_mask",
+ "pixel_values",
+ "codebook_pixel_values",
+ "bool_masked_pos",
+ ],
+ )
+
+ # test if it raises when no input is passed
+ with pytest.raises(ValueError):
+ processor()
+
+ def test_tokenizer_decode(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
+
+ decoded_processor = processor.batch_decode(predicted_ids)
+ decoded_tok = tokenizer.batch_decode(predicted_ids)
+
+ self.assertListEqual(decoded_tok, decoded_processor)
diff --git a/tests/models/fnet/test_modeling_fnet.py b/tests/models/fnet/test_modeling_fnet.py
index 0abf51fb5d75..974d7c2d4e5d 100644
--- a/tests/models/fnet/test_modeling_fnet.py
+++ b/tests/models/fnet/test_modeling_fnet.py
@@ -333,7 +333,12 @@ def recursive_check(tuple_object, dict_object):
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
- msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+ f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
+ f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
+ ),
)
recursive_check(tuple_output, dict_output)
diff --git a/tests/models/fsmt/test_modeling_fsmt.py b/tests/models/fsmt/test_modeling_fsmt.py
index 9e487b609aae..4cc4055a69f2 100644
--- a/tests/models/fsmt/test_modeling_fsmt.py
+++ b/tests/models/fsmt/test_modeling_fsmt.py
@@ -351,9 +351,10 @@ def test_prepare_fsmt_decoder_inputs(self):
config, *_ = self._get_config_and_data()
input_ids = _long_tensor(([4, 4, 2]))
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
- ignore = float("-inf")
+ causal_mask_dtype = torch.float32
+ ignore = torch.finfo(causal_mask_dtype).min
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs(
- config, input_ids, decoder_input_ids
+ config, input_ids, decoder_input_ids, causal_mask_dtype=causal_mask_dtype
)
expected_causal_mask = torch.tensor(
[[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]] # never attend to the final token, because it's pad
diff --git a/tests/models/funnel/test_modeling_tf_funnel.py b/tests/models/funnel/test_modeling_tf_funnel.py
index 422985f7a6fb..faeb9a799510 100644
--- a/tests/models/funnel/test_modeling_tf_funnel.py
+++ b/tests/models/funnel/test_modeling_tf_funnel.py
@@ -17,7 +17,7 @@
import unittest
from transformers import FunnelConfig, is_tf_available
-from transformers.testing_utils import require_tf
+from transformers.testing_utils import require_tf, tooslow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
@@ -371,8 +371,8 @@ def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
+ @tooslow
def test_saved_model_creation(self):
- # This test is too long (>30sec) and makes fail the CI
pass
def test_compile_tf_model(self):
@@ -407,6 +407,6 @@ def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
+ @tooslow
def test_saved_model_creation(self):
- # This test is too long (>30sec) and makes fail the CI
pass
diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py
index c474f519d000..0960daff8360 100644
--- a/tests/models/gpt2/test_modeling_gpt2.py
+++ b/tests/models/gpt2/test_modeling_gpt2.py
@@ -166,6 +166,11 @@ def get_config(
reorder_and_upcast_attn=reorder_and_upcast_attn,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def prepare_config_and_inputs_for_decoder(self):
(
config,
diff --git a/tests/models/gpt2/test_modeling_tf_gpt2.py b/tests/models/gpt2/test_modeling_tf_gpt2.py
index 9790b1c76626..b4752a155c34 100644
--- a/tests/models/gpt2/test_modeling_tf_gpt2.py
+++ b/tests/models/gpt2/test_modeling_tf_gpt2.py
@@ -16,7 +16,7 @@
import unittest
from transformers import GPT2Config, is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers.testing_utils import require_tf, require_tf2onnx, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@@ -294,21 +294,6 @@ def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
- def create_and_check_gpt2_xla_generate_fast(self, config, input_ids, *args):
- config.eos_token_id = None
- config.max_length = 10
- model = TFGPT2LMHeadModel(config=config)
-
- # make sure there are no pad tokens in prompt
- input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id - 1)
-
- generated = model.generate(input_ids)
-
- generate_xla = tf.function(model.generate, jit_compile=True)
- generated_xla = generate_xla(input_ids)
-
- self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())
-
def create_and_check_gpt2_double_head(
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
):
@@ -408,10 +393,6 @@ def test_gpt2_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_lm_head(*config_and_inputs)
- def test_gpt2_xla_generate_fast(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_gpt2_xla_generate_fast(*config_and_inputs)
-
def test_gpt2_double_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_double_head(*config_and_inputs)
@@ -444,6 +425,32 @@ def test_model_from_pretrained(self):
model = TFGPT2Model.from_pretrained(model_name)
self.assertIsNotNone(model)
+ # overwrite from common since ONNX runtime optimization doesn't work with tf.gather() when the argument
+ # `batch_dims` > 0
+ @require_tf2onnx
+ @slow
+ def test_onnx_runtime_optimize(self):
+ if not self.test_onnx:
+ return
+
+ import onnxruntime
+ import tf2onnx
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+
+ # Skip these 2 classes, which use `tf.gather` with `batch_dims=1`
+ if model_class in [TFGPT2ForSequenceClassification, TFGPT2DoubleHeadsModel]:
+ continue
+
+ model = model_class(config)
+ model(model.dummy_inputs)
+
+ onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
+
+ onnxruntime.InferenceSession(onnx_model_proto.SerializeToString())
+
@require_tf
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
@@ -456,7 +463,7 @@ def test_lm_generate_greedy_distilgpt2_batch_special(self):
tokenizer.padding_side = "left"
sentences = ["Today is a beautiful day and", "Yesterday was"]
- input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+ input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
generation_kwargs = {
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
@@ -465,12 +472,12 @@ def test_lm_generate_greedy_distilgpt2_batch_special(self):
"repetition_penalty": 1.3,
}
- output_ids = model.generate(input_ids, **generation_kwargs)
+ output_ids = model.generate(**input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = [
"Today is a beautiful day and I am so happy to be able take part in this amazing event.",
- "Yesterday was a very busy day for the first time since I started writing this post",
+ "Yesterday was a very interesting time for the world to see how much of this is",
]
self.assertListEqual(output_strings, expected_output_string)
@@ -483,7 +490,7 @@ def test_lm_generate_sample_distilgpt2_batch_special(self):
tokenizer.padding_side = "left"
sentences = ["Today is a beautiful day and", "Yesterday was"]
- input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+ input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
generation_kwargs = {
"do_sample": True,
@@ -498,13 +505,13 @@ def test_lm_generate_sample_distilgpt2_batch_special(self):
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
- output_ids = model.generate(input_ids, **generation_kwargs)
+ output_ids = model.generate(**input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = [
- "Today is a beautiful day and we will make you feel very hot/terrific in all",
- "Yesterday was another solid success as news coverage became standard American domestic television hit.",
+ "Today is a beautiful day and we will make you feel very hot/terrific in all your",
+ "Yesterday was known by national television networks as Le Big Show or Wild Dog Jeopard",
]
self.assertListEqual(output_strings, expected_output_string)
@@ -517,7 +524,7 @@ def test_lm_generate_greedy_distilgpt2_beam_search_special(self):
tokenizer.padding_side = "left"
sentences = ["Today is a beautiful day and", "Yesterday was"]
- input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+ input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
generation_kwargs = {
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
@@ -526,37 +533,69 @@ def test_lm_generate_greedy_distilgpt2_beam_search_special(self):
"num_beams": 2,
}
- output_ids = model.generate(input_ids, **generation_kwargs)
+ output_ids = model.generate(**input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = [
"Today is a beautiful day and a great day for all of us.\n\nIām",
- "Yesterday was the first day of the year for the second time in a row,",
+ "Yesterday was the first time that a person has been arrested in the United States for",
]
self.assertListEqual(output_strings, expected_output_string)
+ @slow
+ def test_lm_generate_distilgpt2_left_padding(self):
+ """Tests that the generated text is the same, regarless of left padding"""
+ model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
+ tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
+
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "left"
+
+ generation_kwargs = {
+ "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
+ "no_repeat_ngram_size": 2,
+ "do_sample": False,
+ "repetition_penalty": 1.3,
+ }
+ expected_output_string = (
+ "Today is a beautiful day and I am so happy to be able take part in this amazing event."
+ )
+
+ sentences = ["Today is a beautiful day and"]
+ input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
+ # using default length
+ output_ids = model.generate(**input_ids, **generation_kwargs)
+ output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+ self.assertEqual(output_strings[0], expected_output_string)
+
+ sentences = ["Today is a beautiful day and", "This is a very long input that we absolutely don't care about"]
+ input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
+ # longer max length to capture the full length (remember: it is left padded)
+ output_ids = model.generate(**input_ids, **generation_kwargs, max_length=27)
+ output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+ self.assertEqual(output_strings[0], expected_output_string)
+
@slow
def test_lm_generate_gpt2_greedy_xla(self):
- # TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
- # the underlying problem)
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
- sentences = ["The dog"]
+ sentences = ["The dog", "The flying machine"]
expected_output_strings = [
- "The dog was found in a field near the intersection of West and West Streets.\n\nThe dog",
+ "The dog was found in a field near the intersection of West and West Streets.\n\nThe",
+ "The flying machine is a small, lightweight, and lightweight aircraft that can be used for any type of",
]
- input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+ input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
- output_ids = model.generate(input_ids, do_sample=False)
+ output_ids = model.generate(**input_ids, do_sample=False)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_strings)
xla_generate = tf.function(model.generate, jit_compile=True)
- output_ids = xla_generate(input_ids, do_sample=False)
+ output_ids = xla_generate(**input_ids, do_sample=False)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_strings)
@@ -574,20 +613,48 @@ def test_lm_generate_gpt2_sample_xla(self):
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
- sentence = ["The dog"]
+ sentence = ["The dog", "The flying machine"]
expected_output_string = [
- "The dog owner asked why did our vet decide there needed to be extra ventilation inside because most puppies"
+ "The dog owner asked why did our vet decide there needed to be extra ventilation inside because most"
+ " puppies",
+ "The flying machine was made by an artist who found it difficult to control it as it did not use",
]
expected_output_string_xla = [
- "The dog has been named in connection with the murder of a 20-year-old man in!"
+ "The dog has been named in connection with the murder of a 20-year-old man in",
+ "The flying machine is a new and improved system to operate and operate a new system and system "
+ "system system",
]
- input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
+ input_ids = tokenizer(sentence, return_tensors="tf", padding=True)
- output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
+ output_ids = model.generate(**input_ids, do_sample=True, seed=[7, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_string)
xla_generate = tf.function(model.generate, jit_compile=True)
- output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
+ output_ids = xla_generate(**input_ids, do_sample=True, seed=[7, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_string_xla)
+
+ @slow
+ def test_lm_generate_gpt2_beam_search_xla(self):
+ model = TFGPT2LMHeadModel.from_pretrained("gpt2")
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "left"
+
+ sentences = ["The dog", "The flying machine"]
+ expected_output_strings = [
+ "The dog was found in the backyard of a home in the 6500 block of South Main Street",
+ "The flying machine is a very powerful machine, but it's not a very powerful machine. It's",
+ ]
+ input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
+
+ output_ids = model.generate(**input_ids, do_sample=False, num_beams=2)
+ output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+ self.assertListEqual(output_strings, expected_output_strings)
+
+ xla_generate = tf.function(model.generate, jit_compile=True)
+ output_ids = xla_generate(**input_ids, do_sample=False, num_beams=2)
+ output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+ self.assertListEqual(output_strings, expected_output_strings)
diff --git a/tests/models/gpt2/test_tokenization_gpt2.py b/tests/models/gpt2/test_tokenization_gpt2.py
index b14c113edbc2..d76bc75ccbd5 100644
--- a/tests/models/gpt2/test_tokenization_gpt2.py
+++ b/tests/models/gpt2/test_tokenization_gpt2.py
@@ -175,6 +175,78 @@ def test_padding(self, max_length=15):
padding="max_length",
)
+ def test_padding_if_pad_token_set_slow(self):
+ tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname, pad_token="<pad>")
+
+ # Simple input
+ s = "This is a simple input"
+ s2 = ["This is a simple input looooooooong", "This is a simple input"]
+ p = ("This is a simple input", "This is a pair")
+ p2 = [
+ ("This is a simple input loooooong", "This is a simple input"),
+ ("This is a simple pair loooooong", "This is a simple pair"),
+ ]
+
+ pad_token_id = tokenizer.pad_token_id
+
+ out_s = tokenizer(s, padding="max_length", max_length=30, return_tensors="np")
+ out_s2 = tokenizer(s2, padding=True, truncate=True, return_tensors="np")
+ out_p = tokenizer(*p, padding="max_length", max_length=60, return_tensors="np")
+ out_p2 = tokenizer(p2, padding=True, truncate=True, return_tensors="np")
+
+ # s
+ # test single string max_length padding
+ self.assertEqual(out_s["input_ids"].shape[-1], 30)
+ self.assertTrue(pad_token_id in out_s["input_ids"])
+ self.assertTrue(0 in out_s["attention_mask"])
+
+ # s2
+ # test automatic padding
+ self.assertEqual(out_s2["input_ids"].shape[-1], 33)
+ # long slice doesn't have padding
+ self.assertFalse(pad_token_id in out_s2["input_ids"][0])
+ self.assertFalse(0 in out_s2["attention_mask"][0])
+ # short slice does have padding
+ self.assertTrue(pad_token_id in out_s2["input_ids"][1])
+ self.assertTrue(0 in out_s2["attention_mask"][1])
+
+ # p
+ # test single pair max_length padding
+ self.assertEqual(out_p["input_ids"].shape[-1], 60)
+ self.assertTrue(pad_token_id in out_p["input_ids"])
+ self.assertTrue(0 in out_p["attention_mask"])
+
+ # p2
+ # test automatic padding pair
+ self.assertEqual(out_p2["input_ids"].shape[-1], 52)
+ # long slice pair doesn't have padding
+ self.assertFalse(pad_token_id in out_p2["input_ids"][0])
+ self.assertFalse(0 in out_p2["attention_mask"][0])
+ # short slice pair does have padding
+ self.assertTrue(pad_token_id in out_p2["input_ids"][1])
+ self.assertTrue(0 in out_p2["attention_mask"][1])
+
+ def test_add_bos_token_slow(self):
+ bos_token = "$$$"
+ tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname, bos_token=bos_token, add_bos_token=True)
+
+ s = "This is a simple input"
+ s2 = ["This is a simple input 1", "This is a simple input 2"]
+
+ bos_token_id = tokenizer.bos_token_id
+
+ out_s = tokenizer(s)
+ out_s2 = tokenizer(s2)
+
+ self.assertEqual(out_s.input_ids[0], bos_token_id)
+ self.assertTrue(all(o[0] == bos_token_id for o in out_s2.input_ids))
+
+ decode_s = tokenizer.decode(out_s.input_ids)
+ decode_s2 = tokenizer.batch_decode(out_s2.input_ids)
+
+ self.assertEqual(decode_s.split()[0], bos_token)
+ self.assertTrue(all(d.split()[0] == bos_token for d in decode_s2))
+
# tokenizer has no padding token
def test_padding_different_model_input_name(self):
pass
diff --git a/tests/models/gpt_neo/test_modeling_gpt_neo.py b/tests/models/gpt_neo/test_modeling_gpt_neo.py
index f8607cf1edb9..16a775e2731b 100644
--- a/tests/models/gpt_neo/test_modeling_gpt_neo.py
+++ b/tests/models/gpt_neo/test_modeling_gpt_neo.py
@@ -151,6 +151,11 @@ def get_config(self):
attention_types=self.attention_types,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def prepare_config_and_inputs_for_decoder(self):
(
config,
diff --git a/tests/models/gpt_neox/__init__.py b/tests/models/gpt_neox/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py
new file mode 100644
index 000000000000..0435624f6f11
--- /dev/null
+++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py
@@ -0,0 +1,230 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch GPTNeoX model. """
+
+
+import unittest
+
+from transformers import GPTNeoXConfig, is_torch_available
+from transformers.testing_utils import require_torch, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import GPTNeoXForCausalLM, GPTNeoXModel
+
+
+class GPTNeoXModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_token_type_ids=True,
+ use_labels=True,
+ vocab_size=99,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ num_labels=3,
+ num_choices=4,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_token_type_ids = use_token_type_ids
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.scope = scope
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ token_labels = None
+ if self.use_labels:
+ token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+
+ config = self.get_config()
+
+ return config, input_ids, input_mask, token_labels
+
+ def get_config(self):
+ return GPTNeoXConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ is_decoder=False,
+ initializer_range=self.initializer_range,
+ )
+
+ def prepare_config_and_inputs_for_decoder(self):
+ config, input_ids, input_mask, token_labels = self.prepare_config_and_inputs()
+
+ config.is_decoder = True
+
+ return config, input_ids, input_mask, token_labels
+
+ def create_and_check_model(self, config, input_ids, input_mask):
+ model = GPTNeoXModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ _ = model(input_ids, attention_mask=input_mask)
+ result = model(input_ids)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ def create_and_check_model_as_decoder(self, config, input_ids, input_mask):
+ config.add_cross_attention = True
+ model = GPTNeoXModel(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ def create_and_check_for_causal_lm(self, config, input_ids, input_mask, token_labels):
+ model = GPTNeoXForCausalLM(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, labels=token_labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+ def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask):
+ config.is_decoder = True
+ model = GPTNeoXForCausalLM(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
+ past_key_values = outputs.past_key_values
+
+ # create hypothetical multiple next tokens and extend to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
+
+ # append to next input_ids and attention_mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
+
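+ # run the full sequence without cache and only the new tokens with past_key_values; matching hidden states confirm the cache works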
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask, output_hidden_states=True)
+ output_from_no_past = output_from_no_past["hidden_states"][0]
+ output_from_past = model(
+ next_tokens,
+ attention_mask=next_attention_mask,
+ past_key_values=past_key_values,
+ output_hidden_states=True,
+ )["hidden_states"][0]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, input_ids, input_mask, token_labels = config_and_inputs
+ inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class GPTNeoXModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (GPTNeoXModel, GPTNeoXForCausalLM) if is_torch_available() else ()
+ all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
+ test_pruning = False
+ test_missing_keys = False
+ test_model_parallel = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = GPTNeoXModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=GPTNeoXConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(config, input_ids, input_mask)
+
+ def test_model_as_decoder(self):
+ config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs_for_decoder()
+ self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask)
+
+ def test_model_as_decoder_with_default_input_mask(self):
+ # This regression test was failing with PyTorch < 1.3
+ config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs_for_decoder()
+
+ input_mask = None
+
+ self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask)
+
+ def test_decoder_model_past_large_inputs(self):
+ config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_past_large_inputs(config, input_ids, input_mask)
+
+ def test_model_for_causal_lm(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
+
+ @unittest.skip(reason="Feed forward chunking is not implemented")
+ def test_feed_forward_chunking(self):
+ pass
diff --git a/tests/models/gptj/test_modeling_gptj.py b/tests/models/gptj/test_modeling_gptj.py
index 23ca46eb8280..b8b088d42f1e 100644
--- a/tests/models/gptj/test_modeling_gptj.py
+++ b/tests/models/gptj/test_modeling_gptj.py
@@ -155,6 +155,11 @@ def get_config(self):
rotary_dim=self.rotary_dim,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def prepare_config_and_inputs_for_decoder(self):
(
config,
diff --git a/tests/models/gptj/test_modeling_tf_gptj.py b/tests/models/gptj/test_modeling_tf_gptj.py
index 0d9af0b65087..ec6c15d3f6d6 100644
--- a/tests/models/gptj/test_modeling_tf_gptj.py
+++ b/tests/models/gptj/test_modeling_tf_gptj.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import datetime
import unittest
from transformers import AutoTokenizer, GPTJConfig, is_tf_available
@@ -48,6 +47,7 @@ def __init__(self, parent):
self.use_mc_token_ids = True
self.vocab_size = 99
self.hidden_size = 32
+ self.rotary_dim = 4
self.num_hidden_layers = 5
self.num_attention_heads = 4
self.intermediate_size = 37
@@ -103,6 +103,7 @@ def prepare_config_and_inputs(self):
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
+ rotary_dim=self.rotary_dim,
return_dict=True,
)
@@ -359,10 +360,10 @@ def test_resize_token_embeddings(self):
@require_tf
+@tooslow
+# Marked as @tooslow due to GPU OOM -- but still useful to run locally. Requires ~39GB of RAM.
class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
- @tooslow
def test_lm_generate_gptj(self):
- # Marked as @tooslow due to GPU OOM
model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", from_pt=True)
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
# fmt: off
@@ -372,74 +373,20 @@ def test_lm_generate_gptj(self):
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
- @tooslow
def test_gptj_sample(self):
- # Marked as @tooslow due to GPU OOM (issue #13676)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True)
- tf.random.set_seed(0)
- tokenized = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True)
- input_ids, token_type_ids = tokenized.input_ids, tokenized.token_type_ids
- output_ids = model.generate(input_ids, do_sample=True)
+ tokenized = tokenizer("Today is a nice day and", return_tensors="tf")
+ # forces the generation to happen on CPU, to avoid GPU-related quirks
+ with tf.device(":/CPU:0"):
+ output_ids = model.generate(**tokenized, do_sample=True, seed=[42, 0])
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
- output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5)
- output_seq_tt = model.generate(
- input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5
- )
- output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
- output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)
-
- EXPECTED_OUTPUT_STR = "Today is a nice day and I am taking an hour to sit in the hammock and just enjoy"
-
+ EXPECTED_OUTPUT_STR = "Today is a nice day and I’m going to go for a walk. I’"
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
- self.assertTrue(
- all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
- ) # token_type_ids should change output
- @slow
- @unittest.skip(reason="TF generate currently has no time-based stopping criteria")
- def test_gptj_sample_max_time(self):
- tokenizer = AutoTokenizer.from_pretrained("anton-l/gpt-j-tiny-random")
- model = TFGPTJForCausalLM.from_pretrained("anton-l/gpt-j-tiny-random", from_pt=True)
-
- input_ids = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True).input_ids
-
- MAX_TIME = 0.5
-
- start = datetime.datetime.now()
- model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
- duration = datetime.datetime.now() - start
- self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
- self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
- start = datetime.datetime.now()
- model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
- duration = datetime.datetime.now() - start
- self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
- self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
- start = datetime.datetime.now()
- model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
- duration = datetime.datetime.now() - start
- self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
- self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
- start = datetime.datetime.now()
- model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
- duration = datetime.datetime.now() - start
- self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
- self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
- start = datetime.datetime.now()
- model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
- duration = datetime.datetime.now() - start
- self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
- @tooslow
- def test_batch_generation(self):
- # Marked as @tooslow due to GPU OOM
+ def _get_beam_search_test_objects(self):
model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
@@ -454,42 +401,46 @@ def test_batch_generation(self):
"Hello, my dog is a little",
"Today, I",
]
+ expected_output_sentences = [
+ "Hello, my dog is a little over a year old and has been diagnosed with hip dysplasia",
+ "Today, Iām going to be talking about a topic thatā",
+ ]
+ return model, tokenizer, sentences, expected_output_sentences
- inputs = tokenizer(sentences, return_tensors="tf", padding=True)
- input_ids = inputs["input_ids"]
- token_type_ids = tf.concat(
- [
- tf.zeros((input_ids.shape[0], input_ids.shape[1] - 1), dtype=tf.int64),
- 500 * tf.ones((input_ids.shape[0], 1), dtype=tf.int64),
- ],
- axis=-1,
- )
+ def test_batch_beam_search(self):
+ # Confirms that we get the expected results with left-padded beam search
+ model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()
- outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"])
- outputs_tt = model.generate(
- input_ids=input_ids,
- attention_mask=inputs["attention_mask"],
- token_type_ids=token_type_ids,
- )
+ inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+ outputs = model.generate(**inputs, do_sample=False, num_beams=2)
+ batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ self.assertListEqual(expected_output_sentences, batch_out_sentence)
- inputs_non_padded = tokenizer(sentences[0], return_tensors="tf").input_ids
- output_non_padded = model.generate(input_ids=inputs_non_padded)
+ def test_batch_left_padding(self):
+ # Confirms that left-padding is working properly
+ model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()
+ inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+ inputs_non_padded = tokenizer(sentences[0], return_tensors="tf")
+ output_non_padded = model.generate(**inputs_non_padded, do_sample=False, num_beams=2)
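+ # sentence 1 is shorter and gets num_paddings left-pad tokens in the batch; generating it alone with
+ # max_length reduced by the same amount keeps its token budget identical to the batched run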
num_paddings = (
- shape_list(inputs_non_padded)[-1] - tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy()
+ shape_list(inputs_non_padded["input_ids"])[-1]
+ - tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy()
+ )
+ inputs_padded = tokenizer(sentences[1], return_tensors="tf")
+ output_padded = model.generate(
+ **inputs_padded, do_sample=False, num_beams=2, max_length=model.config.max_length - num_paddings
)
- inputs_padded = tokenizer(sentences[1], return_tensors="tf").input_ids
- output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
-
- batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
- batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+ self.assertListEqual(expected_output_sentences, [non_padded_sentence, padded_sentence])
- expected_output_sentence = [
- "Hello, my dog is a little over a year old and has been diagnosed with a heart murmur",
- "Today, Iām going to share with you a few of my favorite",
- ]
- self.assertListEqual(expected_output_sentence, batch_out_sentence)
- self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output
- self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
+ def test_xla_beam_search(self):
+ # Confirms that XLA is working properly
+ model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()
+
+ inputs = tokenizer(sentences, return_tensors="tf", padding=True)
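+ # jit_compile=True routes the whole generation loop through XLA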
+ xla_generate = tf.function(model.generate, jit_compile=True)
+ outputs_xla = xla_generate(**inputs, do_sample=False, num_beams=2)
+ xla_sentence = tokenizer.batch_decode(outputs_xla, skip_special_tokens=True)
+ self.assertListEqual(expected_output_sentences, xla_sentence)
diff --git a/tests/models/groupvit/__init__.py b/tests/models/groupvit/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/groupvit/test_modeling_groupvit.py b/tests/models/groupvit/test_modeling_groupvit.py
new file mode 100644
index 000000000000..bd6dbd3bc06f
--- /dev/null
+++ b/tests/models/groupvit/test_modeling_groupvit.py
@@ -0,0 +1,669 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch GroupViT model. """
+
+
+import inspect
+import os
+import tempfile
+import unittest
+
+import numpy as np
+
+import requests
+from transformers import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ _config_zero_init,
+ floats_tensor,
+ ids_tensor,
+ random_attention_mask,
+)
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import GroupViTModel, GroupViTTextModel, GroupViTVisionModel
+ from transformers.models.groupvit.modeling_groupvit import GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import CLIPProcessor
+
+
+class GroupViTVisionModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=12,
+ image_size=30,
+ patch_size=2,
+ num_channels=3,
+ is_training=True,
+ hidden_size=32,
+ depths=[6, 3, 3],
+ num_group_tokens=[64, 8, 0],
+ num_output_groups=[64, 8, 8],
+ num_attention_heads=4,
+ intermediate_size=37,
+ dropout=0.1,
+ attention_dropout=0.1,
+ initializer_range=0.02,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.is_training = is_training
+ self.hidden_size = hidden_size
+ self.depths = depths
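+ # each stage stacks several layers, but hidden states are exposed once per stage (plus the initial
+ # embeddings), hence the two different counts below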
+ self.num_hidden_layers = sum(depths)
+ self.expected_num_hidden_layers = len(depths) + 1
+ self.num_group_tokens = num_group_tokens
+ self.num_output_groups = num_output_groups
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.scope = scope
+
+ num_patches = (image_size // patch_size) ** 2
+ # no [CLS] token for GroupViT
+ self.seq_length = num_patches
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+ config = self.get_config()
+
+ return config, pixel_values
+
+ def get_config(self):
+ return GroupViTVisionConfig(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ hidden_size=self.hidden_size,
+ depths=self.depths,
+ num_group_tokens=self.num_group_tokens,
+ num_output_groups=self.num_output_groups,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ dropout=self.dropout,
+ attention_dropout=self.attention_dropout,
+ initializer_range=self.initializer_range,
+ )
+
+ def create_and_check_model(self, config, pixel_values):
+ model = GroupViTVisionModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.num_output_groups[-1], self.hidden_size)
+ )
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class GroupViTVisionModelTest(ModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as GroupViT does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (GroupViTVisionModel,) if is_torch_available() else ()
+
+ test_pruning = False
+ test_torchscript = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = GroupViTVisionModelTester(self)
+ self.config_tester = ConfigTester(
+ self, config_class=GroupViTVisionConfig, has_text_modality=False, hidden_size=37
+ )
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(reason="GroupViT does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ def test_model_common_attributes(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ seq_len = getattr(self.model_tester, "seq_length", None)
+
+ expected_num_attention_outputs = sum(g > 0 for g in self.model_tester.num_group_tokens)
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ # GroupViT returns attention grouping of each stage
+ self.assertEqual(len(attentions), expected_num_attention_outputs)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ # GroupViT returns attention grouping of each stage
+ self.assertEqual(len(attentions), expected_num_attention_outputs)
+
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.attentions
+
+ # GroupViT returns attention grouping of each stage
+ self.assertEqual(len(self_attentions), expected_num_attention_outputs)
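+ # stage i's grouping attention maps its incoming tokens (the previous stage's groups, or the image
+ # patches for the first stage) onto its output groups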
+ for i, self_attn in enumerate(self_attentions):
+ if self_attn is None:
+ continue
+
+ self.assertListEqual(
+ list(self_attentions[i].shape[-2:]),
+ [
+ self.model_tester.num_output_groups[i],
+ self.model_tester.num_output_groups[i - 1] if i > 0 else seq_len,
+ ],
+ )
+
+ def test_training(self):
+ pass
+
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ @unittest.skip(reason="GroupViTVisionModel has no base class and is not available in MODEL_MAPPING")
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ @unittest.skip(reason="GroupViTVisionModel has no base class and is not available in MODEL_MAPPING")
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ # override since the attentions returned by GroupViT are not used to compute the loss, thus they carry no grad
+ def test_retain_grad_hidden_states_attentions(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = self.has_attentions
+
+ # no need to test all models as different heads yield the same functionality
+ model_class = self.all_model_classes[0]
+ model = model_class(config)
+ model.to(torch_device)
+
+ inputs = self._prepare_for_class(inputs_dict, model_class)
+
+ outputs = model(**inputs)
+
+ output = outputs[0]
+
+ if config.is_encoder_decoder:
+ # Seq2Seq models
+ encoder_hidden_states = outputs.encoder_hidden_states[0]
+ encoder_hidden_states.retain_grad()
+
+ decoder_hidden_states = outputs.decoder_hidden_states[0]
+ decoder_hidden_states.retain_grad()
+
+ if self.has_attentions:
+ encoder_attentions = outputs.encoder_attentions[0]
+ encoder_attentions.retain_grad()
+
+ decoder_attentions = outputs.decoder_attentions[0]
+ decoder_attentions.retain_grad()
+
+ cross_attentions = outputs.cross_attentions[0]
+ cross_attentions.retain_grad()
+
+ output.flatten()[0].backward(retain_graph=True)
+
+ self.assertIsNotNone(encoder_hidden_states.grad)
+ self.assertIsNotNone(decoder_hidden_states.grad)
+
+ if self.has_attentions:
+ self.assertIsNotNone(encoder_attentions.grad)
+ self.assertIsNotNone(decoder_attentions.grad)
+ self.assertIsNotNone(cross_attentions.grad)
+ else:
+ # Encoder-/Decoder-only models
+ hidden_states = outputs.hidden_states[0]
+ hidden_states.retain_grad()
+
+ if self.has_attentions:
+ attentions = outputs.attentions[0]
+ attentions.retain_grad()
+
+ output.flatten()[0].backward(retain_graph=True)
+
+ self.assertIsNotNone(hidden_states.grad)
+
+ if self.has_attentions:
+ self.assertIsNone(attentions.grad)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = GroupViTVisionModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class GroupViTTextModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=12,
+ seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_labels=True,
+ vocab_size=99,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ dropout=0.1,
+ attention_dropout=0.1,
+ max_position_embeddings=512,
+ initializer_range=0.02,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+ self.scope = scope
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ if input_mask is not None:
+ batch_size, seq_length = input_mask.shape
+ rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
+ for batch_idx, start_index in enumerate(rnd_start_indices):
+ input_mask[batch_idx, :start_index] = 1
+ input_mask[batch_idx, start_index:] = 0
+
+ config = self.get_config()
+
+ return config, input_ids, input_mask
+
+ def get_config(self):
+ return GroupViTTextConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ dropout=self.dropout,
+ attention_dropout=self.attention_dropout,
+ max_position_embeddings=self.max_position_embeddings,
+ initializer_range=self.initializer_range,
+ )
+
+ def create_and_check_model(self, config, input_ids, input_mask):
+ model = GroupViTTextModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(input_ids, attention_mask=input_mask)
+ result = model(input_ids)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, input_ids, input_mask = config_and_inputs
+ inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class GroupViTTextModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (GroupViTTextModel,) if is_torch_available() else ()
+ test_pruning = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = GroupViTTextModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=GroupViTTextConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_training(self):
+ pass
+
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ @unittest.skip(reason="GroupViTTextModel does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="GroupViTTextModel has no base class and is not available in MODEL_MAPPING")
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ @unittest.skip(reason="GroupViTTextModel has no base class and is not available in MODEL_MAPPING")
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = GroupViTTextModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class GroupViTModelTester:
+ def __init__(self, parent, is_training=True):
+ self.parent = parent
+ self.text_model_tester = GroupViTTextModelTester(parent)
+ self.vision_model_tester = GroupViTVisionModelTester(parent)
+ self.is_training = is_training
+
+ def prepare_config_and_inputs(self):
+ text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
+ vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
+
+ config = self.get_config()
+
+ return config, input_ids, attention_mask, pixel_values
+
+ def get_config(self):
+ return GroupViTConfig.from_text_vision_configs(
+ self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
+ )
+
+ def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
+ model = GroupViTModel(config).to(torch_device).eval()
+ with torch.no_grad():
+ result = model(input_ids, pixel_values, attention_mask)
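+ # contrastive logits: one row per image and one column per text (and the transpose for logits_per_text)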
+ self.parent.assertEqual(
+ result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
+ )
+ self.parent.assertEqual(
+ result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, input_ids, attention_mask, pixel_values = config_and_inputs
+ inputs_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ "return_loss": True,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class GroupViTModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (GroupViTModel,) if is_torch_available() else ()
+ test_head_masking = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_attention_outputs = False
+
+ def setUp(self):
+ self.model_tester = GroupViTModelTester(self)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ @unittest.skip(reason="hidden_states are tested in individual model tests")
+ def test_hidden_states_output(self):
+ pass
+
+ @unittest.skip(reason="input_embeds are tested in individual model tests")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="tested in individual model tests")
+ def test_retain_grad_hidden_states_attentions(self):
+ pass
+
+ @unittest.skip(reason="GroupViTModel does not have input/output embeddings")
+ def test_model_common_attributes(self):
+ pass
+
+ # override as the `logit_scale` parameter initialization is different for GroupViT
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ # check if `logit_scale` is initialized as per the original implementation
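+ # the original implementation follows CLIP and starts the temperature at log(1 / 0.07)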
+ if name == "logit_scale":
+ self.assertAlmostEqual(
+ param.data.item(),
+ np.log(1 / 0.07),
+ delta=1e-3,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ def _create_and_check_torchscript(self, config, inputs_dict):
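+ # overridden from the common test: tracing GroupViTModel requires pixel_values in addition to input_ids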
+ if not self.test_torchscript:
+ return
+
+ configs_no_init = _config_zero_init(config) # To be sure we have no Nan
+ configs_no_init.torchscript = True
+ configs_no_init.return_dict = False
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+
+ try:
+ input_ids = inputs_dict["input_ids"]
+ pixel_values = inputs_dict["pixel_values"] # GroupViT needs pixel_values
+ traced_model = torch.jit.trace(model, (input_ids, pixel_values))
+ except RuntimeError:
+ self.fail("Couldn't trace module.")
+
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
+
+ try:
+ torch.jit.save(traced_model, pt_file_name)
+ except Exception:
+ self.fail("Couldn't save module.")
+
+ try:
+ loaded_model = torch.jit.load(pt_file_name)
+ except Exception:
+ self.fail("Couldn't load module.")
+
+ model.to(torch_device)
+ model.eval()
+
+ loaded_model.to(torch_device)
+ loaded_model.eval()
+
+ model_state_dict = model.state_dict()
+ loaded_model_state_dict = loaded_model.state_dict()
+
+ self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
+
+ models_equal = True
+ for layer_name, p1 in model_state_dict.items():
+ p2 = loaded_model_state_dict[layer_name]
+ if p1.data.ne(p2.data).sum() > 0:
+ models_equal = False
+
+ self.assertTrue(models_equal)
+
+ def test_load_vision_text_config(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ # Save GroupViTConfig and check if we can load GroupViTVisionConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ vision_config = GroupViTVisionConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
+
+ # Save GroupViTConfig and check if we can load GroupViTTextConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ text_config = GroupViTTextConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = GroupViTModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ im = Image.open(requests.get(url, stream=True).raw)
+ return im
+
+
+@require_vision
+@require_torch
+class GroupViTModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference(self):
+ model_name = "nvidia/groupvit-gcc-yfcc"
+ model = GroupViTModel.from_pretrained(model_name)
+ processor = CLIPProcessor.from_pretrained(model_name)
+
+ image = prepare_img()
+ inputs = processor(
+ text=["a photo of a cat", "a photo of a dog"], images=image, padding=True, return_tensors="pt"
+ )
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ self.assertEqual(
+ outputs.logits_per_image.shape,
+ torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
+ )
+ self.assertEqual(
+ outputs.logits_per_text.shape,
+ torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
+ )
+
+ expected_logits = torch.tensor([[13.3523, 6.3629]])
+
+ self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
diff --git a/tests/models/hubert/test_modeling_hubert.py b/tests/models/hubert/test_modeling_hubert.py
index 0055b8346a4b..1e27690bd47a 100644
--- a/tests/models/hubert/test_modeling_hubert.py
+++ b/tests/models/hubert/test_modeling_hubert.py
@@ -16,12 +16,16 @@
import math
+import os
+import pickle
+import tempfile
import unittest
import pytest
from transformers import HubertConfig, is_torch_available
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
+from transformers.utils import is_torch_fx_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
@@ -45,6 +49,9 @@
)
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
+if is_torch_fx_available():
+ from transformers.utils.fx import symbolic_trace
+
class HubertModelTester:
def __init__(
@@ -299,6 +306,7 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
class HubertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
+ fx_compatible = True
test_pruning = False
test_headmasking = False
@@ -417,6 +425,117 @@ def test_initialization(self):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
+ # Hubert cannot be TorchScripted because of torch.nn.utils.weight_norm
+ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+ if not is_torch_fx_available() or not self.fx_compatible:
+ return
+
+ configs_no_init = _config_zero_init(config) # To be sure we have no Nan
+ configs_no_init.return_dict = False
+
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
+
+ try:
+ if model.config.is_encoder_decoder:
+ model.config.use_cache = False # FSMT still requires this hack -> FSMT should probably be refactored similar to BART afterward
+ labels = inputs.get("labels", None)
+ input_names = [
+ "attention_mask",
+ "decoder_attention_mask",
+ "decoder_input_ids",
+ "input_features",
+ "input_ids",
+ "input_values",
+ ]
+ if labels is not None:
+ input_names.append("labels")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+ else:
+ input_names = [
+ "attention_mask",
+ "bbox",
+ "input_features",
+ "input_ids",
+ "input_values",
+ "pixel_values",
+ "token_type_ids",
+ "visual_feats",
+ "visual_pos",
+ ]
+
+ labels = inputs.get("labels", None)
+ start_positions = inputs.get("start_positions", None)
+ end_positions = inputs.get("end_positions", None)
+ if labels is not None:
+ input_names.append("labels")
+ if start_positions is not None:
+ input_names.append("start_positions")
+ if end_positions is not None:
+ input_names.append("end_positions")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+
+ except Exception as e:
+ self.fail(f"Couldn't trace module: {e}")
+
+ def flatten_output(output):
+ flatten = []
+ for x in output:
+ if isinstance(x, (tuple, list)):
+ flatten += flatten_output(x)
+ elif not isinstance(x, torch.Tensor):
+ continue
+ else:
+ flatten.append(x)
+ return flatten
+
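+ # compare every tensor output of the traced module against the eager forward pass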
+ model_output = flatten_output(model_output)
+ traced_output = flatten_output(traced_output)
+ num_outputs = len(model_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], traced_output[i]),
+ f"traced {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
+ # Test that the model can be serialized and restored properly
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+ try:
+ with open(pkl_file_name, "wb") as f:
+ pickle.dump(traced_model, f)
+ with open(pkl_file_name, "rb") as f:
+ loaded = pickle.load(f)
+ except Exception as e:
+ self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+ loaded_output = loaded(**filtered_inputs)
+ loaded_output = flatten_output(loaded_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], loaded_output[i]),
+ f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
diff --git a/tests/models/hubert/test_modeling_tf_hubert.py b/tests/models/hubert/test_modeling_tf_hubert.py
index 156535d7a2b8..871d466d9712 100644
--- a/tests/models/hubert/test_modeling_tf_hubert.py
+++ b/tests/models/hubert/test_modeling_tf_hubert.py
@@ -539,7 +539,8 @@ def test_inference_ctc_robust_batched(self):
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
- "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
+ "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around"
+ " him with the thousands of spectators were trivialities not worth thinking about",
"his instant of panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/models/ibert/test_modeling_ibert.py b/tests/models/ibert/test_modeling_ibert.py
index f8e7b2da2b9e..78ba4d4604d1 100644
--- a/tests/models/ibert/test_modeling_ibert.py
+++ b/tests/models/ibert/test_modeling_ibert.py
@@ -116,6 +116,11 @@ def get_config(self):
quant_mode=True,
)
+ def get_pipeline_config(self):
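+ # pipeline tests feed real tokenizer output, so the tiny config needs a vocabulary large enough to cover those ids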
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
diff --git a/tests/models/imagegpt/test_modeling_imagegpt.py b/tests/models/imagegpt/test_modeling_imagegpt.py
index 57b406f646bc..528532d4cd81 100644
--- a/tests/models/imagegpt/test_modeling_imagegpt.py
+++ b/tests/models/imagegpt/test_modeling_imagegpt.py
@@ -171,6 +171,12 @@ def get_config(
reorder_and_upcast_attn=reorder_and_upcast_attn,
)
+ def get_pipeline_config(self):
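+ # ImageGPT uses 512 pixel clusters plus a start-of-sequence token, and 32x32 images give 1024 positions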
+ config = self.get_config()
+ config.vocab_size = 513
+ config.max_position_embeddings = 1024
+ return config
+
def prepare_config_and_inputs_for_decoder(self):
(
config,
diff --git a/tests/models/layoutlm/test_modeling_layoutlm.py b/tests/models/layoutlm/test_modeling_layoutlm.py
index ec2190598e56..e2d949611d78 100644
--- a/tests/models/layoutlm/test_modeling_layoutlm.py
+++ b/tests/models/layoutlm/test_modeling_layoutlm.py
@@ -215,6 +215,7 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else None
)
+ fx_compatible = True
def setUp(self):
self.model_tester = LayoutLMModelTester(self)
diff --git a/tests/models/layoutlmv2/test_modeling_layoutlmv2.py b/tests/models/layoutlmv2/test_modeling_layoutlmv2.py
index bfcd729df153..3c38373163e4 100644
--- a/tests/models/layoutlmv2/test_modeling_layoutlmv2.py
+++ b/tests/models/layoutlmv2/test_modeling_layoutlmv2.py
@@ -20,7 +20,7 @@
import tempfile
import unittest
-from transformers.testing_utils import require_detectron2, require_torch, slow, torch_device
+from transformers.testing_utils import require_detectron2, require_torch, require_torch_multi_gpu, slow, torch_device
from transformers.utils import is_detectron2_available, is_torch_available
from ...test_configuration_common import ConfigTester
@@ -260,7 +260,7 @@ def prepare_config_and_inputs_for_common(self):
class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
- test_torchscript = False
+ test_torchscript = True
test_mismatched_shapes = False
all_model_classes = (
@@ -285,6 +285,16 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
+ @require_torch_multi_gpu
+ @unittest.skip(
+ reason=(
+ "LayoutLMV2 and its dependency `detectron2` have some layers using `add_module` which doesn't work well"
+ " with `nn.DataParallel`"
+ )
+ )
+ def test_multi_gpu_data_parallel_forward(self):
+ pass
+
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
diff --git a/tests/models/layoutlmv2/test_processor_layoutlmv2.py b/tests/models/layoutlmv2/test_processor_layoutlmv2.py
index e822d177ca66..4f686155adc7 100644
--- a/tests/models/layoutlmv2/test_processor_layoutlmv2.py
+++ b/tests/models/layoutlmv2/test_processor_layoutlmv2.py
@@ -133,6 +133,39 @@ def test_save_load_pretrained_additional_features(self):
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
+ @slow
+ def test_overflowing_tokens(self):
+ # In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences).
+
+ from datasets import load_dataset
+
+ # set up
+ datasets = load_dataset("nielsr/funsd")
+ processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
+
+ def preprocess_data(examples):
+ images = [Image.open(path).convert("RGB") for path in examples["image_path"]]
+ words = examples["words"]
+ boxes = examples["bboxes"]
+ word_labels = examples["ner_tags"]
+ encoded_inputs = processor(
+ images,
+ words,
+ boxes=boxes,
+ word_labels=word_labels,
+ padding="max_length",
+ truncation=True,
+ return_overflowing_tokens=True,
+ stride=50,
+ return_offsets_mapping=True,
+ return_tensors="pt",
+ )
+ return encoded_inputs
+
+ train_data = preprocess_data(datasets["train"])
+
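+ # every overflow chunk must be paired with its own copy of the page image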
+ self.assertEqual(len(train_data["image"]), len(train_data["input_ids"]))
+
# different use cases tests
@require_torch
@@ -182,10 +215,11 @@ def test_processor_case_1(self):
)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
# fmt: off
expected_decoding = "[CLS] 11 : 14 to 11 : 39 a. m 11 : 39 to 11 : 44 a. m. 11 : 44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president ā introductory remarks ā lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from the floor. dr. emil m. mrak, university of cal - ifornia, chairman, trrf board ; sam r. cecil, university of georgia college of agriculture ; dr. stanley charm, tufts university school of medicine ; dr. robert h. cotton, itt continental baking company ; dr. owen fennema, university of wis - consin ; dr. robert e. hardenburg, usda. questions and answers exhibits open capt. jack stoney room trrf scientific advisory council meeting ballroom foyer [SEP]" # noqa: E231
# fmt: on
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -203,10 +237,11 @@ def test_processor_case_1(self):
)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
# fmt: off
expected_decoding = "[CLS] 7 itc limited report and accounts 2013 itc ā s brands : an asset for the nation the consumer needs and aspirations they fulfil, the benefit they generate for millions across itc ā s value chains, the future - ready capabilities that support them, and the value that they create for the country, have made itc ā s brands national assets, adding to india ā s competitiveness. it is itc ā s aspiration to be the no 1 fmcg player in the country, driven by its new fmcg businesses. a recent nielsen report has highlighted that itc's new fmcg businesses are the fastest growing among the top consumer goods companies operating in india. itc takes justifiable pride that, along with generating economic value, these celebrated indian brands also drive the creation of larger societal capital through the virtuous cycle of sustainable and inclusive growth. di wills * ; love delightfully soft skin? aia ans source : https : / / www. industrydocuments. ucsf. edu / docs / snbx0223 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]" # noqa: E231
# fmt: on
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
@slow
@@ -233,7 +268,7 @@ def test_processor_case_2(self):
# verify input_ids
expected_decoding = "[CLS] hello world [SEP]"
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -248,7 +283,7 @@ def test_processor_case_2(self):
# verify input_ids
expected_decoding = "[CLS] hello world [SEP] [PAD] [PAD] [PAD]"
- decoding = tokenizer.decode(input_processor.input_ids[0].tolist())
+ decoding = processor.decode(input_processor.input_ids[0].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
@@ -287,7 +322,7 @@ def test_processor_case_3(self):
# verify input_ids
expected_decoding = "[CLS] weirdly world [SEP]"
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify labels
@@ -309,7 +344,7 @@ def test_processor_case_3(self):
# verify input_ids
expected_decoding = "[CLS] my name is niels [SEP]"
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
@@ -349,10 +384,11 @@ def test_processor_case_4(self):
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
# fmt: off
expected_decoding = "[CLS] what's his name? [SEP] 11 : 14 to 11 : 39 a. m 11 : 39 to 11 : 44 a. m. 11 : 44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president ā introductory remarks ā lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from the floor. dr. emil m. mrak, university of cal - ifornia, chairman, trrf board ; sam r. cecil, university of georgia college of agriculture ; dr. stanley charm, tufts university school of medicine ; dr. robert h. cotton, itt continental baking company ; dr. owen fennema, university of wis - consin ; dr. robert e. hardenburg, usda. questions and answers exhibits open capt. jack stoney room trrf scientific advisory council meeting ballroom foyer [SEP]" # noqa: E231
# fmt: on
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -367,8 +403,9 @@ def test_processor_case_4(self):
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
expected_decoding = "[CLS] what's the time [SEP] 7 itc limited report and accounts 2013 itc ā s [SEP]"
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
@@ -401,7 +438,7 @@ def test_processor_case_5(self):
# verify input_ids
expected_decoding = "[CLS] what's his name? [SEP] hello world [SEP]"
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -417,11 +454,11 @@ def test_processor_case_5(self):
# verify input_ids
expected_decoding = "[CLS] how old is he? [SEP] hello world [SEP] [PAD] [PAD] [PAD]"
- decoding = tokenizer.decode(input_processor.input_ids[0].tolist())
+ decoding = processor.decode(input_processor.input_ids[0].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
expected_decoding = "[CLS] what's the time [SEP] my name is niels [SEP]"
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
diff --git a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
index 1c3f8190c162..049caae64194 100644
--- a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
+++ b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
@@ -181,7 +181,7 @@ def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
vocab = {}
- for (i, token) in enumerate(vocab_tokens):
+ for i, token in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
@@ -1634,11 +1634,9 @@ def test_training_new_tokenizer_with_special_tokens_change(self):
break
self.assertTrue(
find,
- (
- f"'{new_special_token_str}' doesn't appear in the list "
- f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
- f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}"
- ),
+ f"'{new_special_token_str}' doesn't appear in the list "
+ f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
+ f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}",
)
elif special_token not in special_tokens_map:
# The special token must appear identically in the list of the new tokenizer.
@@ -1923,7 +1921,8 @@ def test_maximum_encoding_length_pair_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
@@ -1937,7 +1936,8 @@ def test_maximum_encoding_length_pair_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
# Check the order of Sequence of input ids, overflowing tokens and bbox sequence with truncation
@@ -2183,7 +2183,9 @@ def test_maximum_encoding_length_single_input(self):
sequence = tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)
total_length = len(sequence["input_ids"])
- self.assertGreater(total_length, 4, "Issue with the testing sequence, please update it it's too short")
+ self.assertGreater(
+ total_length, 4, "Issue with the testing sequence, please update it, it's too short"
+ )
# Test with max model input length
model_max_length = tokenizer.model_max_length
@@ -2193,7 +2195,9 @@ def test_maximum_encoding_length_single_input(self):
sequence1 = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
total_length1 = len(sequence1["input_ids"])
self.assertGreater(
- total_length1, model_max_length, "Issue with the testing sequence, please update it it's too short"
+ total_length1,
+ model_max_length,
+ "Issue with the testing sequence, please update it, it's too short",
)
# Simple
@@ -2232,7 +2236,8 @@ def test_maximum_encoding_length_single_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
@@ -2244,7 +2249,8 @@ def test_maximum_encoding_length_single_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
# Check the order of Sequence of input ids, overflowing tokens and bbox sequence with truncation
diff --git a/tests/models/layoutlmv3/__init__.py b/tests/models/layoutlmv3/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py b/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py
new file mode 100644
index 000000000000..9d05a4b6658e
--- /dev/null
+++ b/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py
@@ -0,0 +1,213 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_pytesseract, require_torch
+from transformers.utils import is_pytesseract_available, is_torch_available
+
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+ import torch
+
+if is_pytesseract_available():
+ from PIL import Image
+
+ from transformers import LayoutLMv3FeatureExtractor
+
+
+class LayoutLMv3FeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ image_size=18,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=18,
+ apply_ocr=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.apply_ocr = apply_ocr
+
+ def prepare_feat_extract_dict(self):
+ return {"do_resize": self.do_resize, "size": self.size, "apply_ocr": self.apply_ocr}
+
+
+@require_torch
+@require_pytesseract
+class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = LayoutLMv3FeatureExtractor if is_pytesseract_available() else None
+
+ def setUp(self):
+ self.feature_extract_tester = LayoutLMv3FeatureExtractionTester(self)
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "size"))
+ self.assertTrue(hasattr(feature_extractor, "apply_ocr"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoding = feature_extractor(image_inputs[0], return_tensors="pt")
+ self.assertEqual(
+ encoding.pixel_values.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ self.assertIsInstance(encoding.words, list)
+ self.assertIsInstance(encoding.boxes, list)
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ def test_call_numpy(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random numpy tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, np.ndarray)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ def test_call_pytorch(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PyTorch tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, torch.Tensor)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ def test_LayoutLMv3_integration_test(self):
+ # with apply_OCR = True
+ feature_extractor = LayoutLMv3FeatureExtractor()
+
+ from datasets import load_dataset
+
+ ds = load_dataset("hf-internal-testing/fixtures_docvqa", split="test")
+
+ image = Image.open(ds[0]["file"]).convert("RGB")
+
+ encoding = feature_extractor(image, return_tensors="pt")
+
+ self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224))
+ self.assertEqual(len(encoding.words), len(encoding.boxes))
+
+ # fmt: off
+ # the words and boxes were obtained with Tesseract 4.1.1
+ expected_words = [['11:14', 'to', '11:39', 'a.m', '11:39', 'to', '11:44', 'a.m.', '11:44', 'a.m.', 'to', '12:25', 'p.m.', '12:25', 'to', '12:58', 'p.m.', '12:58', 'to', '4:00', 'p.m.', '2:00', 'to', '5:00', 'p.m.', 'Coffee', 'Break', 'Coffee', 'will', 'be', 'served', 'for', 'men', 'and', 'women', 'in', 'the', 'lobby', 'adjacent', 'to', 'exhibit', 'area.', 'Please', 'move', 'into', 'exhibit', 'area.', '(Exhibits', 'Open)', 'TRRF', 'GENERAL', 'SESSION', '(PART', '|)', 'Presiding:', 'Lee', 'A.', 'Waller', 'TRRF', 'Vice', 'President', 'āIntroductory', 'Remarksā', 'Lee', 'A.', 'Waller,', 'TRRF', 'Vice', 'Presi-', 'dent', 'Individual', 'Interviews', 'with', 'TRRF', 'Public', 'Board', 'Members', 'and', 'Sci-', 'entific', 'Advisory', 'Council', 'Mem-', 'bers', 'Conducted', 'by', 'TRRF', 'Treasurer', 'Philip', 'G.', 'Kuehn', 'to', 'get', 'answers', 'which', 'the', 'public', 'refrigerated', 'warehousing', 'industry', 'is', 'looking', 'for.', 'Plus', 'questions', 'from', 'the', 'floor.', 'Dr.', 'Emil', 'M.', 'Mrak,', 'University', 'of', 'Cal-', 'ifornia,', 'Chairman,', 'TRRF', 'Board;', 'Sam', 'R.', 'Cecil,', 'University', 'of', 'Georgia', 'College', 'of', 'Agriculture;', 'Dr.', 'Stanley', 'Charm,', 'Tufts', 'University', 'School', 'of', 'Medicine;', 'Dr.', 'Robert', 'H.', 'Cotton,', 'ITT', 'Continental', 'Baking', 'Company;', 'Dr.', 'Owen', 'Fennema,', 'University', 'of', 'Wis-', 'consin;', 'Dr.', 'Robert', 'E.', 'Hardenburg,', 'USDA.', 'Questions', 'and', 'Answers', 'Exhibits', 'Open', 'Capt.', 'Jack', 'Stoney', 'Room', 'TRRF', 'Scientific', 'Advisory', 'Council', 'Meeting', 'Ballroom', 'Foyer']] # noqa: E231
+ expected_boxes = [[[141, 57, 214, 69], [228, 58, 252, 69], [141, 75, 216, 88], [230, 79, 280, 88], [142, 260, 218, 273], [230, 261, 255, 273], [143, 279, 218, 290], [231, 282, 290, 291], [143, 342, 218, 354], [231, 345, 289, 355], [202, 362, 227, 373], [143, 379, 220, 392], [231, 382, 291, 394], [144, 714, 220, 726], [231, 715, 256, 726], [144, 732, 220, 745], [232, 736, 291, 747], [144, 769, 218, 782], [231, 770, 256, 782], [141, 788, 202, 801], [215, 791, 274, 804], [143, 826, 204, 838], [215, 826, 240, 838], [142, 844, 202, 857], [215, 847, 274, 859], [334, 57, 427, 69], [440, 57, 522, 69], [369, 75, 461, 88], [469, 75, 516, 88], [528, 76, 562, 88], [570, 76, 667, 88], [675, 75, 711, 87], [721, 79, 778, 88], [789, 75, 840, 88], [369, 97, 470, 107], [484, 94, 507, 106], [518, 94, 562, 107], [576, 94, 655, 110], [668, 94, 792, 109], [804, 95, 829, 107], [369, 113, 465, 125], [477, 116, 547, 125], [562, 113, 658, 125], [671, 116, 748, 125], [761, 113, 811, 125], [369, 131, 465, 143], [477, 133, 548, 143], [563, 130, 698, 145], [710, 130, 802, 146], [336, 171, 412, 183], [423, 171, 572, 183], [582, 170, 716, 184], [728, 171, 817, 187], [829, 171, 844, 186], [338, 197, 482, 212], [507, 196, 557, 209], [569, 196, 595, 208], [610, 196, 702, 209], [505, 214, 583, 226], [595, 214, 656, 227], [670, 215, 807, 227], [335, 259, 543, 274], [556, 259, 708, 272], [372, 279, 422, 291], [435, 279, 460, 291], [474, 279, 574, 292], [587, 278, 664, 291], [676, 278, 738, 291], [751, 279, 834, 291], [372, 298, 434, 310], [335, 341, 483, 354], [497, 341, 655, 354], [667, 341, 728, 354], [740, 341, 825, 354], [335, 360, 430, 372], [442, 360, 534, 372], [545, 359, 687, 372], [697, 360, 754, 372], [765, 360, 823, 373], [334, 378, 428, 391], [440, 378, 577, 394], [590, 378, 705, 391], [720, 378, 801, 391], [334, 397, 400, 409], [370, 416, 529, 429], [544, 416, 576, 432], [587, 416, 665, 428], [677, 416, 814, 429], [372, 435, 452, 450], [465, 434, 495, 447], [511, 434, 600, 447], [611, 436, 637, 447], [649, 436, 694, 451], [705, 438, 824, 447], [369, 453, 452, 466], [464, 454, 509, 466], [522, 453, 611, 469], [625, 453, 792, 469], [370, 472, 556, 488], [570, 472, 684, 487], [697, 472, 718, 485], [732, 472, 835, 488], [369, 490, 411, 503], [425, 490, 484, 503], [496, 490, 635, 506], [645, 490, 707, 503], [718, 491, 761, 503], [771, 490, 840, 503], [336, 510, 374, 521], [388, 510, 447, 522], [460, 510, 489, 521], [503, 510, 580, 522], [592, 509, 736, 525], [745, 509, 770, 522], [781, 509, 840, 522], [338, 528, 434, 541], [448, 528, 596, 541], [609, 527, 687, 540], [700, 528, 792, 541], [336, 546, 397, 559], [407, 546, 431, 559], [443, 546, 525, 560], [537, 546, 680, 562], [688, 546, 714, 559], [722, 546, 837, 562], [336, 565, 449, 581], [461, 565, 485, 577], [497, 565, 665, 581], [681, 565, 718, 577], [732, 565, 837, 580], [337, 584, 438, 597], [452, 583, 521, 596], [535, 584, 677, 599], [690, 583, 787, 596], [801, 583, 825, 596], [338, 602, 478, 615], [492, 602, 530, 614], [543, 602, 638, 615], [650, 602, 676, 614], [688, 602, 788, 615], [802, 602, 843, 614], [337, 621, 502, 633], [516, 621, 615, 637], [629, 621, 774, 636], [789, 621, 827, 633], [337, 639, 418, 652], [432, 640, 571, 653], [587, 639, 731, 655], [743, 639, 769, 652], [780, 639, 841, 652], [338, 658, 440, 673], [455, 658, 491, 670], [508, 658, 602, 671], [616, 658, 638, 670], [654, 658, 835, 674], [337, 677, 429, 689], [337, 714, 482, 726], [495, 714, 548, 726], [561, 714, 683, 726], [338, 770, 461, 782], [474, 769, 554, 785], [489, 788, 562, 803], 
+ [576, 788, 643, 801], [656, 787, 751, 804], [764, 788, 844, 801], [334, 825, 421, 838], [430, 824, 574, 838], [584, 824, 723, 841], [335, 844, 450, 857], [464, 843, 583, 860], [628, 862, 755, 875], [769, 861, 848, 878]]] # noqa: E231
+ # fmt: on
+
+ self.assertListEqual(encoding.words, expected_words)
+ self.assertListEqual(encoding.boxes, expected_boxes)
+
+ # with apply_ocr = False
+ feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
+
+ encoding = feature_extractor(image, return_tensors="pt")
+
+ self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224))
diff --git a/tests/models/layoutlmv3/test_modeling_layoutlmv3.py b/tests/models/layoutlmv3/test_modeling_layoutlmv3.py
new file mode 100644
index 000000000000..d5c8d42d2217
--- /dev/null
+++ b/tests/models/layoutlmv3/test_modeling_layoutlmv3.py
@@ -0,0 +1,399 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch LayoutLMv3 model. """
+
+import copy
+import unittest
+
+from transformers.models.auto import get_values
+from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.utils import cached_property, is_torch_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+ LayoutLMv3Config,
+ LayoutLMv3ForQuestionAnswering,
+ LayoutLMv3ForSequenceClassification,
+ LayoutLMv3ForTokenClassification,
+ LayoutLMv3Model,
+ )
+ from transformers.models.layoutlmv3.modeling_layoutlmv3 import LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import LayoutLMv3FeatureExtractor
+
+
+class LayoutLMv3ModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=2,
+ num_channels=3,
+ image_size=4,
+ patch_size=2,
+ text_seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_token_type_ids=True,
+ use_labels=True,
+ vocab_size=99,
+ hidden_size=36,
+ num_hidden_layers=3,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ coordinate_size=6,
+ shape_size=6,
+ num_labels=3,
+ num_choices=4,
+ scope=None,
+ range_bbox=1000,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.text_seq_length = text_seq_length
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_token_type_ids = use_token_type_ids
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.coordinate_size = coordinate_size
+ self.shape_size = shape_size
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.scope = scope
+ self.range_bbox = range_bbox
+
+ # LayoutLMv3's sequence length equals the number of text tokens + number of patches + 1 (we add 1 for the CLS token)
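+ # e.g. with the defaults above (text_seq_length=7, image_size=4, patch_size=2): (4 // 2) ** 2 + 1 = 5 image tokens, so seq_length = 7 + 5 = 12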
+ self.text_seq_length = text_seq_length
+ self.image_seq_length = (image_size // patch_size) ** 2 + 1
+ self.seq_length = self.text_seq_length + self.image_seq_length
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.text_seq_length], self.vocab_size)
+
+ bbox = ids_tensor([self.batch_size, self.text_seq_length, 4], self.range_bbox)
+ # Ensure that each bbox is legal, i.e. coordinates are ordered so that x0 <= x1 and y0 <= y1
+ for i in range(bbox.shape[0]):
+ for j in range(bbox.shape[1]):
+ if bbox[i, j, 3] < bbox[i, j, 1]:
+ t = bbox[i, j, 3]
+ bbox[i, j, 3] = bbox[i, j, 1]
+ bbox[i, j, 1] = t
+ if bbox[i, j, 2] < bbox[i, j, 0]:
+ t = bbox[i, j, 2]
+ bbox[i, j, 2] = bbox[i, j, 0]
+ bbox[i, j, 0] = t
+
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.text_seq_length])
+
+ token_type_ids = None
+ if self.use_token_type_ids:
+ token_type_ids = ids_tensor([self.batch_size, self.text_seq_length], self.type_vocab_size)
+
+ sequence_labels = None
+ token_labels = None
+ if self.use_labels:
+ sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+ token_labels = ids_tensor([self.batch_size, self.text_seq_length], self.num_labels)
+
+ config = LayoutLMv3Config(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ initializer_range=self.initializer_range,
+ coordinate_size=self.coordinate_size,
+ shape_size=self.shape_size,
+ input_size=self.image_size,
+ patch_size=self.patch_size,
+ )
+
+ return config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
+
+ def create_and_check_model(
+ self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
+ ):
+ model = LayoutLMv3Model(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # text + image
+ result = model(input_ids, pixel_values=pixel_values)
+ result = model(
+ input_ids, bbox=bbox, pixel_values=pixel_values, attention_mask=input_mask, token_type_ids=token_type_ids
+ )
+ result = model(input_ids, bbox=bbox, pixel_values=pixel_values, token_type_ids=token_type_ids)
+ result = model(input_ids, bbox=bbox, pixel_values=pixel_values)
+
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ # text only
+ result = model(input_ids)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.text_seq_length, self.hidden_size)
+ )
+
+ # image only
+ result = model(pixel_values=pixel_values)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.image_seq_length, self.hidden_size)
+ )
+
+ def create_and_check_for_sequence_classification(
+ self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
+ ):
+ config.num_labels = self.num_labels
+ model = LayoutLMv3ForSequenceClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ bbox=bbox,
+ pixel_values=pixel_values,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ labels=sequence_labels,
+ )
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def create_and_check_for_token_classification(
+ self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
+ ):
+ config.num_labels = self.num_labels
+ model = LayoutLMv3ForTokenClassification(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ bbox=bbox,
+ pixel_values=pixel_values,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ labels=token_labels,
+ )
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.text_seq_length, self.num_labels))
+
+ def create_and_check_for_question_answering(
+ self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
+ ):
+ model = LayoutLMv3ForQuestionAnswering(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ bbox=bbox,
+ pixel_values=pixel_values,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ start_positions=sequence_labels,
+ end_positions=sequence_labels,
+ )
+ self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+ self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ bbox,
+ pixel_values,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ ) = config_and_inputs
+ inputs_dict = {
+ "input_ids": input_ids,
+ "bbox": bbox,
+ "pixel_values": pixel_values,
+ "token_type_ids": token_type_ids,
+ "attention_mask": input_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class LayoutLMv3ModelTest(ModelTesterMixin, unittest.TestCase):
+
+ test_pruning = False
+ test_torchscript = False
+ test_mismatched_shapes = False
+
+ all_model_classes = (
+ (
+ LayoutLMv3Model,
+ LayoutLMv3ForSequenceClassification,
+ LayoutLMv3ForTokenClassification,
+ LayoutLMv3ForQuestionAnswering,
+ )
+ if is_torch_available()
+ else ()
+ )
+
+ def setUp(self):
+ self.model_tester = LayoutLMv3ModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=LayoutLMv3Config, hidden_size=37)
+
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ inputs_dict = copy.deepcopy(inputs_dict)
+ if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
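+ # multiple-choice models expect inputs of shape (batch_size, num_choices, seq_length), so each tensor input is duplicated along a new choices dimension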
+ inputs_dict = {
+ k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
+ if isinstance(v, torch.Tensor) and v.ndim > 1
+ else v
+ for k, v in inputs_dict.items()
+ }
+ if return_labels:
+ if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
+ inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
+ elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
+ inputs_dict["start_positions"] = torch.zeros(
+ self.model_tester.batch_size, dtype=torch.long, device=torch_device
+ )
+ inputs_dict["end_positions"] = torch.zeros(
+ self.model_tester.batch_size, dtype=torch.long, device=torch_device
+ )
+ elif model_class in [
+ *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
+ ]:
+ inputs_dict["labels"] = torch.zeros(
+ self.model_tester.batch_size, dtype=torch.long, device=torch_device
+ )
+ elif model_class in [
+ *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
+ ]:
+ inputs_dict["labels"] = torch.zeros(
+ (self.model_tester.batch_size, self.model_tester.text_seq_length),
+ dtype=torch.long,
+ device=torch_device,
+ )
+
+ return inputs_dict
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_various_embeddings(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ for type in ["absolute", "relative_key", "relative_key_query"]:
+ config_and_inputs[0].position_embedding_type = type
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_sequence_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+
+ def test_for_token_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
+
+ def test_for_question_answering(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = LayoutLMv3Model.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_torch
+class LayoutLMv3ModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return LayoutLMv3FeatureExtractor(apply_ocr=False) if is_vision_available() else None
+
+ @slow
+ def test_inference_no_head(self):
+ model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base").to(torch_device)
+
+ feature_extractor = self.default_feature_extractor
+ image = prepare_img()
+ pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(torch_device)
+
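+ # minimal dummy text input: two token ids with a matching (1, 2, 4) tensor of bounding boxes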
+ input_ids = torch.tensor([[1, 2]])
+ bbox = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).unsqueeze(0)
+
+ # forward pass
+ outputs = model(
+ input_ids=input_ids.to(torch_device),
+ bbox=bbox.to(torch_device),
+ pixel_values=pixel_values.to(torch_device),
+ )
+
+ # verify the last hidden states
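+ # 199 = 2 text tokens + 197 visual tokens (a 224x224 image with the default 16x16 patches gives 196 patches + 1 CLS token)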
+ expected_shape = torch.Size((1, 199, 768))
+ self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
+ expected_slice = torch.tensor(
+ [[-0.0529, 0.3618, 0.1632], [-0.1587, -0.1667, -0.0400], [-0.1557, -0.1671, -0.0505]]
+ ).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
diff --git a/tests/models/layoutlmv3/test_processor_layoutlmv3.py b/tests/models/layoutlmv3/test_processor_layoutlmv3.py
new file mode 100644
index 000000000000..a01b0a00cd90
--- /dev/null
+++ b/tests/models/layoutlmv3/test_processor_layoutlmv3.py
@@ -0,0 +1,446 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+from typing import List
+
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
+from transformers.models.layoutlmv3 import LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast
+from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES
+from transformers.testing_utils import require_pytesseract, require_tokenizers, require_torch, slow
+from transformers.utils import FEATURE_EXTRACTOR_NAME, cached_property, is_pytesseract_available
+
+
+if is_pytesseract_available():
+ from PIL import Image
+
+ from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3Processor
+
+
+@require_pytesseract
+@require_tokenizers
+class LayoutLMv3ProcessorTest(unittest.TestCase):
+ tokenizer_class = LayoutLMv3Tokenizer
+ rust_tokenizer_class = LayoutLMv3TokenizerFast
+
+ def setUp(self):
+ # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
+ vocab = [
+ "l",
+ "o",
+ "w",
+ "e",
+ "r",
+ "s",
+ "t",
+ "i",
+ "d",
+ "n",
+ "\u0120",
+ "\u0120l",
+ "\u0120n",
+ "\u0120lo",
+ "\u0120low",
+ "er",
+ "\u0120lowest",
+ "\u0120newer",
+ "\u0120wider",
+ "",
+ ]
+ self.tmpdirname = tempfile.mkdtemp()
+ vocab_tokens = dict(zip(vocab, range(len(vocab))))
+ merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
+ self.special_tokens_map = {"unk_token": "<unk>"}
+
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+ self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
+ with open(self.vocab_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(vocab_tokens) + "\n")
+ with open(self.merges_file, "w", encoding="utf-8") as fp:
+ fp.write("\n".join(merges))
+
+ feature_extractor_map = {
+ "do_resize": True,
+ "size": 224,
+ "apply_ocr": True,
+ }
+
+ self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
+ with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(feature_extractor_map) + "\n")
+
+ def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
+ return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
+ return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_tokenizers(self, **kwargs) -> List[PreTrainedTokenizerBase]:
+ return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]
+
+ def get_feature_extractor(self, **kwargs):
+ return LayoutLMv3FeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdirname)
+
+ def test_save_load_pretrained_default(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ processor.save_pretrained(self.tmpdirname)
+ processor = LayoutLMv3Processor.from_pretrained(self.tmpdirname)
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
+ self.assertIsInstance(processor.tokenizer, (LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast))
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor)
+
+ def test_save_load_pretrained_additional_features(self):
+ processor = LayoutLMv3Processor(feature_extractor=self.get_feature_extractor(), tokenizer=self.get_tokenizer())
+ processor.save_pretrained(self.tmpdirname)
+
+ # slow tokenizer
+ tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
+ feature_extractor_add_kwargs = self.get_feature_extractor(do_resize=False, size=30)
+
+ processor = LayoutLMv3Processor.from_pretrained(
+ self.tmpdirname, use_fast=False, bos_token="(BOS)", eos_token="(EOS)", do_resize=False, size=30
+ )
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+ self.assertIsInstance(processor.tokenizer, LayoutLMv3Tokenizer)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor)
+
+ # fast tokenizer
+ tokenizer_add_kwargs = self.get_rust_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
+ feature_extractor_add_kwargs = self.get_feature_extractor(do_resize=False, size=30)
+
+ processor = LayoutLMv3Processor.from_pretrained(
+ self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_resize=False, size=30
+ )
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+ self.assertIsInstance(processor.tokenizer, LayoutLMv3TokenizerFast)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor)
+
+
+# integration tests for the different use cases
+@require_torch
+@require_pytesseract
+class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
+ @cached_property
+ def get_images(self):
+ # we verify our implementation on 2 document images from the DocVQA dataset
+ from datasets import load_dataset
+
+ ds = load_dataset("hf-internal-testing/fixtures_docvqa", split="test")
+
+ image_1 = Image.open(ds[0]["file"]).convert("RGB")
+ image_2 = Image.open(ds[1]["file"]).convert("RGB")
+
+ return image_1, image_2
+
+ @cached_property
+ def get_tokenizers(self):
+ slow_tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base", add_visual_labels=False)
+ fast_tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base", add_visual_labels=False)
+ return [slow_tokenizer, fast_tokenizer]
+
+ @slow
+ def test_processor_case_1(self):
+ # case 1: document image classification (training, inference) + token classification (inference), apply_ocr = True
+
+ feature_extractor = LayoutLMv3FeatureExtractor()
+ tokenizers = self.get_tokenizers
+ images = self.get_images
+
+ for tokenizer in tokenizers:
+ processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ # not batched
+ input_feat_extract = feature_extractor(images[0], return_tensors="pt")
+ input_processor = processor(images[0], return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify image
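+ # comparing the summed pixel values is a lightweight check that the processor forwards images to the feature extractor unchanged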
+ self.assertAlmostEqual(
+ input_feat_extract["pixel_values"].sum(), input_processor["pixel_values"].sum(), delta=1e-2
+ )
+
+ # verify input_ids
+ # this was obtained with Tesseract 4.1.1
+ # fmt: off
+ expected_decoding = " 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President āIntroductory Remarksā Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer" # noqa: E231
+ # fmt: on
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # batched
+ input_feat_extract = feature_extractor(images, return_tensors="pt")
+ input_processor = processor(images, padding=True, return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify images
+ self.assertAlmostEqual(
+ input_feat_extract["pixel_values"].sum(), input_processor["pixel_values"].sum(), delta=1e-2
+ )
+
+ # verify input_ids
+ # this was obtained with Tesseract 4.1.1
+ # fmt: off
+ expected_decoding = " 7 ITC Limited REPORT AND ACCOUNTS 2013 ITCās Brands: An Asset for the Nation The consumer needs and aspirations they fulfil, the benefit they generate for millions across ITCās value chains, the future-ready capabilities that support them, and the value that they create for the country, have made ITCās brands national assets, adding to Indiaās competitiveness. It is ITCās aspiration to be the No 1 FMCG player in the country, driven by its new FMCG businesses. A recent Nielsen report has highlighted that ITC's new FMCG businesses are the fastest growing among the top consumer goods companies operating in India. ITC takes justifiable pride that, along with generating economic value, these celebrated Indian brands also drive the creation of larger societal capital through the virtuous cycle of sustainable and inclusive growth. DI WILLS * ; LOVE DELIGHTFULLY SOFT SKIN? aia Ans Source: https://www.industrydocuments.ucsf.edu/docs/snbx0223" # noqa: E231
+ # fmt: on
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ @slow
+ def test_processor_case_2(self):
+ # case 2: document image classification (training, inference) + token classification (inference), apply_ocr=False
+
+ feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
+ tokenizers = self.get_tokenizers
+ images = self.get_images
+
+ for tokenizer in tokenizers:
+ processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ # not batched
+ words = ["hello", "world"]
+ boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
+ input_processor = processor(images[0], words, boxes=boxes, return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["input_ids", "bbox", "attention_mask", "pixel_values"]
+ actual_keys = list(input_processor.keys())
+ for key in expected_keys:
+ self.assertIn(key, actual_keys)
+
+ # verify input_ids
+ expected_decoding = " hello world"
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # batched
+ words = [["hello", "world"], ["my", "name", "is", "niels"]]
+ boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
+ input_processor = processor(images, words, boxes=boxes, padding=True, return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify input_ids
+ expected_decoding = " hello world"
+ decoding = processor.decode(input_processor.input_ids[0].tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # verify bbox
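+ # the leading and trailing [0, 0, 0, 0] boxes belong to the special tokens; "niels" is split into two subword tokens, so its box [1, 1, 2, 3] is repeated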
+ expected_bbox = [
+ [0, 0, 0, 0],
+ [3, 2, 5, 1],
+ [6, 7, 4, 2],
+ [3, 9, 2, 4],
+ [1, 1, 2, 3],
+ [1, 1, 2, 3],
+ [0, 0, 0, 0],
+ ]
+ self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox)
+
+ @slow
+ def test_processor_case_3(self):
+ # case 3: token classification (training), apply_ocr=False
+
+ feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
+ tokenizers = self.get_tokenizers
+ images = self.get_images
+
+ for tokenizer in tokenizers:
+ processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ # not batched
+ words = ["weirdly", "world"]
+ boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
+ word_labels = [1, 2]
+ input_processor = processor(images[0], words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify input_ids
+ expected_decoding = " weirdly world"
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # verify labels
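+ # only the first subword piece of each word keeps its label; continuation pieces and special tokens are set to -100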
+ expected_labels = [-100, 1, -100, 2, -100]
+ self.assertListEqual(input_processor.labels.squeeze().tolist(), expected_labels)
+
+ # batched
+ words = [["hello", "world"], ["my", "name", "is", "niels"]]
+ boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
+ word_labels = [[1, 2], [6, 3, 10, 2]]
+ input_processor = processor(
+ images, words, boxes=boxes, word_labels=word_labels, padding=True, return_tensors="pt"
+ )
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify input_ids
+ expected_decoding = " my name is niels"
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # verify bbox
+ expected_bbox = [
+ [0, 0, 0, 0],
+ [3, 2, 5, 1],
+ [6, 7, 4, 2],
+ [3, 9, 2, 4],
+ [1, 1, 2, 3],
+ [1, 1, 2, 3],
+ [0, 0, 0, 0],
+ ]
+ self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox)
+
+ # verify labels
+ expected_labels = [-100, 6, 3, 10, 2, -100, -100]
+ self.assertListEqual(input_processor.labels[1].tolist(), expected_labels)
+
+ @slow
+ def test_processor_case_4(self):
+ # case 4: visual question answering (inference), apply_ocr=True
+
+ feature_extractor = LayoutLMv3FeatureExtractor()
+ tokenizers = self.get_tokenizers
+ images = self.get_images
+
+ for tokenizer in tokenizers:
+ processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ # not batched
+ question = "What's his name?"
+ input_processor = processor(images[0], question, return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify input_ids
+ # this was obtained with Tesseract 4.1.1
+ # fmt: off
+ expected_decoding = " What's his name? 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President āIntroductory Remarksā Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer" # noqa: E231
+ # fmt: on
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # batched
+ questions = ["How old is he?", "what's the time"]
+ input_processor = processor(
+ images, questions, padding="max_length", max_length=20, truncation=True, return_tensors="pt"
+ )
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify input_ids
+ # this was obtained with Tesseract 4.1.1
+ expected_decoding = " what's the time 7 ITC Limited REPORT AND ACCOUNTS 2013 ITC"
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # verify bbox
+ # fmt: off
+ expected_bbox = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 45, 67, 80], [72, 56, 109, 67], [72, 56, 109, 67], [116, 56, 189, 67], [198, 59, 253, 66], [257, 59, 285, 66], [289, 59, 365, 66], [289, 59, 365, 66], [289, 59, 365, 66], [372, 59, 407, 66], [74, 136, 161, 158], [74, 136, 161, 158], [0, 0, 0, 0]] # noqa: E231
+ # fmt: on
+ self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox)
+
+ @slow
+ def test_processor_case_5(self):
+ # case 5: visual question answering (inference), apply_ocr=False
+
+ feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
+ tokenizers = self.get_tokenizers
+ images = self.get_images
+
+ for tokenizer in tokenizers:
+ processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ # not batched
+ question = "What's his name?"
+ words = ["hello", "world"]
+ boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
+ input_processor = processor(images[0], question, words, boxes, return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify input_ids
+ expected_decoding = " What's his name? hello world"
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # batched
+ questions = ["How old is he?", "what's the time"]
+ words = [["hello", "world"], ["my", "name", "is", "niels"]]
+ boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
+ input_processor = processor(images, questions, words, boxes, padding=True, return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify input_ids
+ expected_decoding = " How old is he? hello world"
+ decoding = processor.decode(input_processor.input_ids[0].tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ expected_decoding = " what's the time my name is niels"
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # verify bbox
+ expected_bbox = [[6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3], [1, 1, 2, 3], [0, 0, 0, 0]]
+ self.assertListEqual(input_processor.bbox[1].tolist()[-5:], expected_bbox)
diff --git a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
new file mode 100644
index 000000000000..239939ca2696
--- /dev/null
+++ b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
@@ -0,0 +1,2349 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import json
+import os
+import re
+import shutil
+import tempfile
+import unittest
+from typing import List
+
+from transformers import AddedToken, LayoutLMv3TokenizerFast, SpecialTokensMixin, is_tf_available, is_torch_available
+from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES, LayoutLMv3Tokenizer
+from transformers.testing_utils import is_pt_tf_cross_test, require_pandas, require_tokenizers, require_torch, slow
+
+from ...test_tokenization_common import SMALL_TRAINING_CORPUS, TokenizerTesterMixin, merge_model_tokenizer_mappings
+
+
+@require_tokenizers
+@require_pandas
+class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+ tokenizer_class = LayoutLMv3Tokenizer
+ rust_tokenizer_class = LayoutLMv3TokenizerFast
+ test_rust_tokenizer = True
+ # determined by the tokenization algorithm and the way it's decoded by the fast tokenizers
+ space_between_special_tokens = False
+ test_seq2seq = False
+ from_pretrained_kwargs = {"cls_token": "<s>"}
+
+ def get_words_and_boxes(self):
+ words = ["lower", "newer"]
+ boxes = [[423, 237, 440, 251], [427, 272, 441, 287]]
+
+ return words, boxes
+
+ def get_words_and_boxes_batch(self):
+ words = [["lower", "newer"], ["new", "low"]]
+ boxes = [
+ [[423, 237, 440, 251], [427, 272, 441, 287]],
+ [[961, 885, 992, 912], [256, 38, 330, 58]],
+ ]
+
+ return words, boxes
+
+ def get_question_words_and_boxes(self):
+ question = "what's his name?"
+ words = ["lower", "newer"]
+ boxes = [[423, 237, 440, 251], [427, 272, 441, 287]]
+
+ return question, words, boxes
+
+ def get_question_words_and_boxes_batch(self):
+ questions = ["what's his name?", "how is he called?"]
+ words = [["lower", "newer"], ["newer", "lower"]]
+ boxes = [
+ [[423, 237, 440, 251], [427, 272, 441, 287]],
+ [[256, 38, 330, 58], [256, 38, 330, 58]],
+ ]
+
+ return questions, words, boxes
+
+ def setUp(self):
+ super().setUp()
+
+ # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
+ vocab = [
+ "l",
+ "o",
+ "w",
+ "e",
+ "r",
+ "s",
+ "t",
+ "i",
+ "d",
+ "n",
+ "\u0120",
+ "\u0120l",
+ "\u0120n",
+ "\u0120lo",
+ "\u0120low",
+ "er",
+ "\u0120lowest",
+ "\u0120newer",
+ "\u0120wider",
+ "",
+ ]
+ vocab_tokens = dict(zip(vocab, range(len(vocab))))
+ merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
+ self.special_tokens_map = {"unk_token": "<unk>"}
+
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+ self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
+ with open(self.vocab_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(vocab_tokens) + "\n")
+ with open(self.merges_file, "w", encoding="utf-8") as fp:
+ fp.write("\n".join(merges))
+
+ def get_tokenizer(self, **kwargs):
+ kwargs.update(self.special_tokens_map)
+ return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_rust_tokenizer(self, **kwargs):
+ kwargs.update(self.special_tokens_map)
+ return LayoutLMv3TokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_input_output_texts(self, tokenizer):
+ input_text = "lower newer"
+ output_text = "lower newer"
+ return input_text, output_text
+
+ def test_full_tokenizer(self):
+ tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
+ text = "lower newer"
+ bpe_tokens = ["Ä low", "er", "Ä ", "n", "e", "w", "er"]
+ tokens = tokenizer.tokenize(text) # , add_prefix_space=True)
+ self.assertListEqual(tokens, bpe_tokens)
+
+ input_tokens = tokens + [tokenizer.unk_token]
+ input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
+ self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+
+ @slow
+ def test_sequence_builders(self):
+ tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutlmv3-base")
+
+ question, words, boxes = self.get_question_words_and_boxes()
+
+ text = tokenizer.encode(
+ question.split(),
+ boxes=[tokenizer.pad_token_box for _ in range(len(question.split()))],
+ add_special_tokens=False,
+ )
+ text_2 = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+
+ encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
+
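+ # for this RoBERTa-style tokenizer, 0 is <s> and 2 is </s>; a pair is encoded as <s> A </s></s> B </s>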
+ assert encoded_pair == [0] + text + [2] + [2] + text_2 + [2]
+
+ def test_add_special_tokens(self):
+ tokenizers: List[LayoutLMv3Tokenizer] = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+
+ special_token = "[SPECIAL_TOKEN]"
+ special_token_box = [1000, 1000, 1000, 1000]
+
+ tokenizer.add_special_tokens({"cls_token": special_token})
+ encoded_special_token = tokenizer.encode(
+ [special_token], boxes=[special_token_box], add_special_tokens=False
+ )
+ self.assertEqual(len(encoded_special_token), 1)
+
+ decoded = tokenizer.decode(encoded_special_token, skip_special_tokens=True)
+ self.assertTrue(special_token not in decoded)
+
+ def test_add_tokens_tokenizer(self):
+ tokenizers: List[LayoutLMv3Tokenizer] = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ vocab_size = tokenizer.vocab_size
+ all_size = len(tokenizer)
+
+ self.assertNotEqual(vocab_size, 0)
+
+ # We usually have added tokens from the start in tests because our vocab fixtures are
+ # smaller than the original vocabs - let's not assert this
+ # self.assertEqual(vocab_size, all_size)
+
+ new_toks = ["aaaaa", "bbbbbb", "cccccccccdddddddd"]
+ added_toks = tokenizer.add_tokens(new_toks)
+ vocab_size_2 = tokenizer.vocab_size
+ all_size_2 = len(tokenizer)
+
+ self.assertNotEqual(vocab_size_2, 0)
+ self.assertEqual(vocab_size, vocab_size_2)
+ self.assertEqual(added_toks, len(new_toks))
+ self.assertEqual(all_size_2, all_size + len(new_toks))
+
+ words = "aaaaa bbbbbb low cccccccccdddddddd l".split()
+ boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))]
+
+ tokens = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+
+ self.assertGreaterEqual(len(tokens), 4)
+ self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
+ self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
+
+ new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
+ added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
+ vocab_size_3 = tokenizer.vocab_size
+ all_size_3 = len(tokenizer)
+
+ self.assertNotEqual(vocab_size_3, 0)
+ self.assertEqual(vocab_size, vocab_size_3)
+ self.assertEqual(added_toks_2, len(new_toks_2))
+ self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
+
+ words = ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l".split()
+ boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))]
+
+ tokens = tokenizer.encode(
+ words,
+ boxes=boxes,
+ add_special_tokens=False,
+ )
+
+ self.assertGreaterEqual(len(tokens), 6)
+ self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
+ self.assertGreater(tokens[0], tokens[1])
+ self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
+ self.assertGreater(tokens[-2], tokens[-3])
+ self.assertEqual(tokens[0], tokenizer.eos_token_id)
+ self.assertEqual(tokens[-2], tokenizer.pad_token_id)
+
+ @require_tokenizers
+ def test_encode_decode_with_spaces(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+
+ new_toks = [AddedToken("[ABC]", normalized=False), AddedToken("[DEF]", normalized=False)]
+ tokenizer.add_tokens(new_toks)
+ input = "[ABC][DEF][ABC][DEF]"
+ if self.space_between_special_tokens:
+ output = "[ABC] [DEF] [ABC] [DEF]"
+ else:
+ output = input
+ encoded = tokenizer.encode(input.split(), boxes=boxes, add_special_tokens=False)
+ decoded = tokenizer.decode(encoded, spaces_between_special_tokens=self.space_between_special_tokens)
+ self.assertIn(decoded, [output, output.lower()])
+
+ @unittest.skip("Not implemented")
+ def test_right_and_left_truncation(self):
+ pass
+
+ def test_encode_plus_with_padding(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+
+ # check correct behaviour if no pad_token_id exists, and add one if needed
+ self._check_no_pad_token_padding(tokenizer, words)
+
+ padding_size = 10
+ padding_idx = tokenizer.pad_token_id
+
+ encoded_sequence = tokenizer.encode_plus(words, boxes=boxes, return_special_tokens_mask=True)
+ input_ids = encoded_sequence["input_ids"]
+ special_tokens_mask = encoded_sequence["special_tokens_mask"]
+ sequence_length = len(input_ids)
+
+ # Test 'longest' and 'no_padding' don't do anything
+ tokenizer.padding_side = "right"
+
+ not_padded_sequence = tokenizer.encode_plus(
+ words,
+ boxes=boxes,
+ padding=False,
+ return_special_tokens_mask=True,
+ )
+ not_padded_input_ids = not_padded_sequence["input_ids"]
+
+ not_padded_special_tokens_mask = not_padded_sequence["special_tokens_mask"]
+ not_padded_sequence_length = len(not_padded_input_ids)
+
+ self.assertTrue(sequence_length == not_padded_sequence_length)
+ self.assertTrue(input_ids == not_padded_input_ids)
+ self.assertTrue(special_tokens_mask == not_padded_special_tokens_mask)
+
+ not_padded_sequence = tokenizer.encode_plus(
+ words,
+ boxes=boxes,
+ padding=False,
+ return_special_tokens_mask=True,
+ )
+ not_padded_input_ids = not_padded_sequence["input_ids"]
+
+ not_padded_special_tokens_mask = not_padded_sequence["special_tokens_mask"]
+ not_padded_sequence_length = len(not_padded_input_ids)
+
+ self.assertTrue(sequence_length == not_padded_sequence_length)
+ self.assertTrue(input_ids == not_padded_input_ids)
+ self.assertTrue(special_tokens_mask == not_padded_special_tokens_mask)
+
+ # Test right padding
+ tokenizer.padding_side = "right"
+
+ right_padded_sequence = tokenizer.encode_plus(
+ words,
+ boxes=boxes,
+ max_length=sequence_length + padding_size,
+ padding="max_length",
+ return_special_tokens_mask=True,
+ )
+ right_padded_input_ids = right_padded_sequence["input_ids"]
+
+ right_padded_special_tokens_mask = right_padded_sequence["special_tokens_mask"]
+ right_padded_sequence_length = len(right_padded_input_ids)
+
+ self.assertTrue(sequence_length + padding_size == right_padded_sequence_length)
+ self.assertTrue(input_ids + [padding_idx] * padding_size == right_padded_input_ids)
+ self.assertTrue(special_tokens_mask + [1] * padding_size == right_padded_special_tokens_mask)
+
+ # Test left padding
+ tokenizer.padding_side = "left"
+ left_padded_sequence = tokenizer.encode_plus(
+ words,
+ boxes=boxes,
+ max_length=sequence_length + padding_size,
+ padding="max_length",
+ return_special_tokens_mask=True,
+ )
+ left_padded_input_ids = left_padded_sequence["input_ids"]
+ left_padded_special_tokens_mask = left_padded_sequence["special_tokens_mask"]
+ left_padded_sequence_length = len(left_padded_input_ids)
+
+ self.assertTrue(sequence_length + padding_size == left_padded_sequence_length)
+ self.assertTrue([padding_idx] * padding_size + input_ids == left_padded_input_ids)
+ self.assertTrue([1] * padding_size + special_tokens_mask == left_padded_special_tokens_mask)
+
+ if "token_type_ids" in tokenizer.model_input_names:
+ token_type_ids = encoded_sequence["token_type_ids"]
+ left_padded_token_type_ids = left_padded_sequence["token_type_ids"]
+ right_padded_token_type_ids = right_padded_sequence["token_type_ids"]
+
+ assert token_type_ids + [0] * padding_size == right_padded_token_type_ids
+ assert [0] * padding_size + token_type_ids == left_padded_token_type_ids
+
+ if "attention_mask" in tokenizer.model_input_names:
+ attention_mask = encoded_sequence["attention_mask"]
+ right_padded_attention_mask = right_padded_sequence["attention_mask"]
+ left_padded_attention_mask = left_padded_sequence["attention_mask"]
+
+ self.assertTrue(attention_mask + [0] * padding_size == right_padded_attention_mask)
+ self.assertTrue([0] * padding_size + attention_mask == left_padded_attention_mask)
+
+ def test_internal_consistency(self):
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+
+ tokens = []
+ for word in words:
+ tokens.extend(tokenizer.tokenize(word))
+ ids = tokenizer.convert_tokens_to_ids(tokens)
+ ids_2 = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ self.assertListEqual(ids, ids_2)
+
+ tokens_2 = tokenizer.convert_ids_to_tokens(ids)
+ self.assertNotEqual(len(tokens_2), 0)
+ text_2 = tokenizer.decode(ids)
+ self.assertIsInstance(text_2, str)
+
+ output_text = " lower newer"
+ self.assertEqual(text_2, output_text)
+
+ def test_mask_output(self):
+ tokenizers = self.get_tokenizers(fast=False, do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+
+ if (
+ tokenizer.build_inputs_with_special_tokens.__qualname__.split(".")[0] != "PreTrainedTokenizer"
+ and "token_type_ids" in tokenizer.model_input_names
+ ):
+ information = tokenizer.encode_plus(words, boxes=boxes, add_special_tokens=True)
+ sequences, mask = information["input_ids"], information["token_type_ids"]
+ self.assertEqual(len(sequences), len(mask))
+
+ def test_number_of_added_tokens(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+
+ # test 1: single sequence
+ words, boxes = self.get_words_and_boxes()
+
+ sequences = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ attached_sequences = tokenizer.encode(words, boxes=boxes, add_special_tokens=True)
+
+ # Method is implemented (e.g. not GPT-2)
+ if len(attached_sequences) != 2:
+ self.assertEqual(
+ tokenizer.num_special_tokens_to_add(pair=False), len(attached_sequences) - len(sequences)
+ )
+
+ # test 2: two sequences
+ question, words, boxes = self.get_question_words_and_boxes()
+
+ sequences = tokenizer.encode(question, words, boxes=boxes, add_special_tokens=False)
+ attached_sequences = tokenizer.encode(question, words, boxes=boxes, add_special_tokens=True)
+
+ # Method is implemented (e.g. not GPT-2)
+ if len(attached_sequences) != 2:
+ self.assertEqual(
+ tokenizer.num_special_tokens_to_add(pair=True), len(attached_sequences) - len(sequences)
+ )
+
+ def test_padding_to_max_length(self):
+ """We keep this test for backward compatibility but it should be removed when `pad_to_max_length` will be deprecated"""
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+ padding_size = 10
+
+ # check correct behaviour if no pad_token_id exists, and add one if needed
+ self._check_no_pad_token_padding(tokenizer, words)
+
+ padding_idx = tokenizer.pad_token_id
+
+ # Check that it correctly pads when a maximum length is specified along with the padding flag set to True
+ tokenizer.padding_side = "right"
+ encoded_sequence = tokenizer.encode(words, boxes=boxes)
+ sequence_length = len(encoded_sequence)
+ # FIXME: the next line should be padding(max_length) to avoid warning
+ padded_sequence = tokenizer.encode(
+ words, boxes=boxes, max_length=sequence_length + padding_size, pad_to_max_length=True
+ )
+ padded_sequence_length = len(padded_sequence)
+ assert sequence_length + padding_size == padded_sequence_length
+ assert encoded_sequence + [padding_idx] * padding_size == padded_sequence
+
+ # Check that nothing is done when a maximum length is not specified
+ encoded_sequence = tokenizer.encode(words, boxes=boxes)
+ sequence_length = len(encoded_sequence)
+
+ tokenizer.padding_side = "right"
+ padded_sequence_right = tokenizer.encode(words, boxes=boxes, pad_to_max_length=True)
+ padded_sequence_right_length = len(padded_sequence_right)
+ assert sequence_length == padded_sequence_right_length
+ assert encoded_sequence == padded_sequence_right
+
+ def test_padding(self, max_length=50):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id)
+ pad_token_id = tokenizer_p.pad_token_id
+
+ # Encode - Simple input
+ words, boxes = self.get_words_and_boxes()
+ input_r = tokenizer_r.encode(words, boxes=boxes, max_length=max_length, pad_to_max_length=True)
+ input_p = tokenizer_p.encode(words, boxes=boxes, max_length=max_length, pad_to_max_length=True)
+ self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
+ input_r = tokenizer_r.encode(words, boxes=boxes, max_length=max_length, padding="max_length")
+ input_p = tokenizer_p.encode(words, boxes=boxes, max_length=max_length, padding="max_length")
+ self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
+
+ input_r = tokenizer_r.encode(words, boxes=boxes, padding="longest")
+ input_p = tokenizer_p.encode(words, boxes=boxes, padding=True)
+ self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id)
+
+ # Encode - Pair input
+ question, words, boxes = self.get_question_words_and_boxes()
+ input_r = tokenizer_r.encode(
+ question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True
+ )
+ input_p = tokenizer_p.encode(
+ question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True
+ )
+ self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
+ input_r = tokenizer_r.encode(question, words, boxes=boxes, max_length=max_length, padding="max_length")
+ input_p = tokenizer_p.encode(question, words, boxes=boxes, max_length=max_length, padding="max_length")
+ self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
+ input_r = tokenizer_r.encode(question, words, boxes=boxes, padding=True)
+ input_p = tokenizer_p.encode(question, words, boxes=boxes, padding="longest")
+ self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id)
+
+ # Encode_plus - Simple input
+ words, boxes = self.get_words_and_boxes()
+ input_r = tokenizer_r.encode_plus(words, boxes=boxes, max_length=max_length, pad_to_max_length=True)
+ input_p = tokenizer_p.encode_plus(words, boxes=boxes, max_length=max_length, pad_to_max_length=True)
+ self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+ input_r = tokenizer_r.encode_plus(words, boxes=boxes, max_length=max_length, padding="max_length")
+ input_p = tokenizer_p.encode_plus(words, boxes=boxes, max_length=max_length, padding="max_length")
+ self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+
+ input_r = tokenizer_r.encode_plus(words, boxes=boxes, padding="longest")
+ input_p = tokenizer_p.encode_plus(words, boxes=boxes, padding=True)
+ self.assert_padded_input_match(
+ input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
+ )
+
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+
+ # Encode_plus - Pair input
+ question, words, boxes = self.get_question_words_and_boxes()
+ input_r = tokenizer_r.encode_plus(
+ question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True
+ )
+ input_p = tokenizer_p.encode_plus(
+ question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True
+ )
+ self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+ input_r = tokenizer_r.encode_plus(
+ question, words, boxes=boxes, max_length=max_length, padding="max_length"
+ )
+ input_p = tokenizer_p.encode_plus(
+ question, words, boxes=boxes, max_length=max_length, padding="max_length"
+ )
+ self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+ input_r = tokenizer_r.encode_plus(question, words, boxes=boxes, padding="longest")
+ input_p = tokenizer_p.encode_plus(question, words, boxes=boxes, padding=True)
+ self.assert_padded_input_match(
+ input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
+ )
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+
+ # Batch_encode_plus - Simple input
+ words, boxes = self.get_words_and_boxes_batch()
+
+ input_r = tokenizer_r.batch_encode_plus(
+ words,
+ boxes=boxes,
+ max_length=max_length,
+ pad_to_max_length=True,
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ words,
+ boxes=boxes,
+ max_length=max_length,
+ pad_to_max_length=True,
+ )
+ self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
+
+ input_r = tokenizer_r.batch_encode_plus(
+ words,
+ boxes=boxes,
+ max_length=max_length,
+ padding="max_length",
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ words,
+ boxes=boxes,
+ max_length=max_length,
+ padding="max_length",
+ )
+ self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
+
+ input_r = tokenizer_r.batch_encode_plus(
+ words,
+ boxes=boxes,
+ max_length=max_length,
+ padding="longest",
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ words,
+ boxes=boxes,
+ max_length=max_length,
+ padding=True,
+ )
+ self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
+
+ input_r = tokenizer_r.batch_encode_plus(words, boxes=boxes, padding="longest")
+ input_p = tokenizer_p.batch_encode_plus(words, boxes=boxes, padding=True)
+ self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
+
+ # Batch_encode_plus - Pair input
+ questions, words, boxes = self.get_question_words_and_boxes_batch()
+
+ input_r = tokenizer_r.batch_encode_plus(
+ list(zip(questions, words)),
+ is_pair=True,
+ boxes=boxes,
+ max_length=max_length,
+ truncation=True,
+ padding="max_length",
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ list(zip(questions, words)),
+ is_pair=True,
+ boxes=boxes,
+ max_length=max_length,
+ truncation=True,
+ padding="max_length",
+ )
+ self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
+
+ input_r = tokenizer_r.batch_encode_plus(
+ list(zip(questions, words)),
+ is_pair=True,
+ boxes=boxes,
+ padding=True,
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ list(zip(questions, words)),
+ is_pair=True,
+ boxes=boxes,
+ padding="longest",
+ )
+ self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
+
+ # Using pad on single examples after tokenization
+ words, boxes = self.get_words_and_boxes()
+ input_r = tokenizer_r.encode_plus(words, boxes=boxes)
+ input_r = tokenizer_r.pad(input_r)
+
+ input_p = tokenizer_r.encode_plus(words, boxes=boxes)
+ input_p = tokenizer_r.pad(input_p)
+
+ self.assert_padded_input_match(
+ input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
+ )
+
+ # Using pad on single examples after tokenization
+ input_r = tokenizer_r.encode_plus(words, boxes=boxes)
+ input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
+
+ input_p = tokenizer_r.encode_plus(words, boxes=boxes)
+ input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
+
+ self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
+
+ # Using pad after tokenization
+ words, boxes = self.get_words_and_boxes_batch()
+ input_r = tokenizer_r.batch_encode_plus(
+ words,
+ boxes=boxes,
+ )
+ input_r = tokenizer_r.pad(input_r)
+
+ input_p = tokenizer_r.batch_encode_plus(
+ words,
+ boxes=boxes,
+ )
+ input_p = tokenizer_r.pad(input_p)
+
+ self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
+
+ # Using pad after tokenization
+ words, boxes = self.get_words_and_boxes_batch()
+ input_r = tokenizer_r.batch_encode_plus(
+ words,
+ boxes=boxes,
+ )
+ input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
+
+ input_p = tokenizer_r.batch_encode_plus(
+ words,
+ boxes=boxes,
+ )
+ input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
+
+ self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
+
+ def test_call(self):
+ # Tests that all calls wrap to encode_plus and batch_encode_plus
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ # Test not batched
+ words, boxes = self.get_words_and_boxes()
+ encoded_sequences_1 = tokenizer.encode_plus(words, boxes=boxes)
+ encoded_sequences_2 = tokenizer(words, boxes=boxes)
+ self.assertEqual(encoded_sequences_1, encoded_sequences_2)
+
+ # Test not batched pairs
+ question, words, boxes = self.get_question_words_and_boxes()
+ encoded_sequences_1 = tokenizer.encode_plus(words, boxes=boxes)
+ encoded_sequences_2 = tokenizer(words, boxes=boxes)
+ self.assertEqual(encoded_sequences_1, encoded_sequences_2)
+
+ # Test batched
+ words, boxes = self.get_words_and_boxes_batch()
+ encoded_sequences_1 = tokenizer.batch_encode_plus(words, is_pair=False, boxes=boxes)
+ encoded_sequences_2 = tokenizer(words, boxes=boxes)
+ self.assertEqual(encoded_sequences_1, encoded_sequences_2)
+
+ def test_batch_encode_plus_batch_sequence_length(self):
+ # Tests that all encoded values have the correct size
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes_batch()
+
+ encoded_sequences = [
+ tokenizer.encode_plus(words_example, boxes=boxes_example)
+ for words_example, boxes_example in zip(words, boxes)
+ ]
+ encoded_sequences_batch = tokenizer.batch_encode_plus(words, is_pair=False, boxes=boxes, padding=False)
+ self.assertListEqual(
+ encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch)
+ )
+
+ maximum_length = len(
+ max([encoded_sequence["input_ids"] for encoded_sequence in encoded_sequences], key=len)
+ )
+
+ # check correct behaviour if no pad_token_id exists and add it if needed
+ self._check_no_pad_token_padding(tokenizer, words)
+
+ encoded_sequences_padded = [
+ tokenizer.encode_plus(
+ words_example, boxes=boxes_example, max_length=maximum_length, padding="max_length"
+ )
+ for words_example, boxes_example in zip(words, boxes)
+ ]
+
+ encoded_sequences_batch_padded = tokenizer.batch_encode_plus(
+ words, is_pair=False, boxes=boxes, padding=True
+ )
+ self.assertListEqual(
+ encoded_sequences_padded,
+ self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch_padded),
+ )
+
+ # check that 'longest' is insensitive to a max length
+ encoded_sequences_batch_padded_1 = tokenizer.batch_encode_plus(
+ words, is_pair=False, boxes=boxes, padding=True
+ )
+ encoded_sequences_batch_padded_2 = tokenizer.batch_encode_plus(
+ words, is_pair=False, boxes=boxes, max_length=maximum_length + 10, padding="longest"
+ )
+ for key in encoded_sequences_batch_padded_1.keys():
+ self.assertListEqual(
+ encoded_sequences_batch_padded_1[key],
+ encoded_sequences_batch_padded_2[key],
+ )
+
+ # check that 'no_padding' is insensitive to a max length
+ encoded_sequences_batch_padded_1 = tokenizer.batch_encode_plus(
+ words, is_pair=False, boxes=boxes, padding=False
+ )
+ encoded_sequences_batch_padded_2 = tokenizer.batch_encode_plus(
+ words, is_pair=False, boxes=boxes, max_length=maximum_length + 10, padding=False
+ )
+ for key in encoded_sequences_batch_padded_1.keys():
+ self.assertListEqual(
+ encoded_sequences_batch_padded_1[key],
+ encoded_sequences_batch_padded_2[key],
+ )
+
+ @unittest.skip("batch_encode_plus does not handle overflowing tokens.")
+ def test_batch_encode_plus_overflowing_tokens(self):
+ pass
+
+ def test_batch_encode_plus_padding(self):
+ # Test that padded sequences are equivalent between batch_encode_plus and encode_plus
+
+ # Right padding tests
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes_batch()
+
+ max_length = 100
+
+ # check correct behaviour if no pad_token_id exists and add it if needed
+ self._check_no_pad_token_padding(tokenizer, words)
+
+ encoded_sequences = [
+ tokenizer.encode_plus(
+ words_example, boxes=boxes_example, max_length=max_length, padding="max_length"
+ )
+ for words_example, boxes_example in zip(words, boxes)
+ ]
+ encoded_sequences_batch = tokenizer.batch_encode_plus(
+ words, is_pair=False, boxes=boxes, max_length=max_length, padding="max_length"
+ )
+ self.assertListEqual(
+ encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch)
+ )
+
+ # Left padding tests
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ tokenizer.padding_side = "left"
+ words, boxes = self.get_words_and_boxes_batch()
+
+ max_length = 100
+
+ # check correct behaviour if no pad_token_id exists and add it if needed
+ self._check_no_pad_token_padding(tokenizer, words)
+
+ encoded_sequences = [
+ tokenizer.encode_plus(
+ words_example, boxes=boxes_example, max_length=max_length, padding="max_length"
+ )
+ for words_example, boxes_example in zip(words, boxes)
+ ]
+ encoded_sequences_batch = tokenizer.batch_encode_plus(
+ words, is_pair=False, boxes=boxes, max_length=max_length, padding="max_length"
+ )
+ self.assertListEqual(
+ encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch)
+ )
+
+ def test_padding_to_multiple_of(self):
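+ # pad_to_multiple_of should only take effect when padding is requested: with padding=True
+ # every returned sequence length must be a multiple of 8, while without padding the short
+ # test inputs are expected not to land on a multiple of 8.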
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ if tokenizer.pad_token is None:
+ self.skipTest("No padding token.")
+ else:
+ words, boxes = self.get_words_and_boxes()
+
+ # empty_tokens = tokenizer([""], [[]], padding=True, pad_to_multiple_of=8)
+ normal_tokens = tokenizer(words, boxes=boxes, padding=True, pad_to_multiple_of=8)
+ # for key, value in empty_tokens.items():
+ # self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8")
+ for key, value in normal_tokens.items():
+ self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8")
+
+ normal_tokens = tokenizer(words, boxes=boxes, pad_to_multiple_of=8)
+ for key, value in normal_tokens.items():
+ self.assertNotEqual(len(value) % 8, 0, f"BatchEncoding.{key} is unexpectedly a multiple of 8")
+
+ # Should also work with truncation
+ normal_tokens = tokenizer(words, boxes=boxes, padding=True, truncation=True, pad_to_multiple_of=8)
+ for key, value in normal_tokens.items():
+ self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8")
+
+ # truncating to a length that is not a multiple of pad_to_multiple_of raises an error
+ self.assertRaises(
+ ValueError,
+ tokenizer.__call__,
+ words,
+ boxes=boxes,
+ padding=True,
+ truncation=True,
+ max_length=12,
+ pad_to_multiple_of=8,
+ )
+
+ def test_tokenizer_slow_store_full_signature(self):
+ signature = inspect.signature(self.tokenizer_class.__init__)
+ tokenizer = self.get_tokenizer()
+
+ for parameter_name, parameter in signature.parameters.items():
+ if parameter.default != inspect.Parameter.empty:
+ self.assertIn(parameter_name, tokenizer.init_kwargs)
+
+ def test_build_inputs_with_special_tokens(self):
+ if not self.test_slow_tokenizer:
+ # as we don't have a slow version, we can't compare the outputs between slow and fast versions
+ return
+
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ # Input tokens id
+ words, boxes = self.get_words_and_boxes()
+ input_simple = tokenizer_p.encode(words, boxes=boxes, add_special_tokens=False)
+ input_pair = tokenizer_p.encode(words, boxes=boxes, add_special_tokens=False)
+
+ # Generate output
+ output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple)
+ output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple)
+ self.assertEqual(output_p, output_r)
+
+ # Generate pair output
+ output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair)
+ output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
+ self.assertEqual(output_p, output_r)
+
+ def test_special_tokens_mask_input_pairs(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+ encoded_sequence = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ encoded_sequence_dict = tokenizer.encode_plus(
+ words,
+ boxes=boxes,
+ add_special_tokens=True,
+ return_special_tokens_mask=True,
+ # add_prefix_space=False,
+ )
+ encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
+ special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
+ self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
+
+ filtered_sequence = [
+ (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
+ ]
+ filtered_sequence = [x for x in filtered_sequence if x is not None]
+ self.assertEqual(encoded_sequence, filtered_sequence)
+
+ def test_special_tokens_mask(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+ # Testing single inputs
+ encoded_sequence = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ encoded_sequence_dict = tokenizer.encode_plus(
+ words, boxes=boxes, add_special_tokens=True, return_special_tokens_mask=True
+ )
+ encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
+ special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
+ self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
+
+ filtered_sequence = [x for i, x in enumerate(encoded_sequence_w_special) if not special_tokens_mask[i]]
+ self.assertEqual(encoded_sequence, filtered_sequence)
+
+ def test_save_and_load_tokenizer(self):
+ # safety check on max_len default value so we are sure the test works
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ self.assertNotEqual(tokenizer.model_max_length, 42)
+
+ # Now let's start the test
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ # Isolate this from the other tests because we save additional tokens/etc
+ words, boxes = self.get_words_and_boxes()
+ tmpdirname = tempfile.mkdtemp()
+
+ before_tokens = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ before_vocab = tokenizer.get_vocab()
+ tokenizer.save_pretrained(tmpdirname)
+
+ after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
+ after_tokens = after_tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ after_vocab = after_tokenizer.get_vocab()
+ self.assertListEqual(before_tokens, after_tokens)
+ self.assertDictEqual(before_vocab, after_vocab)
+
+ shutil.rmtree(tmpdirname)
+
+ def test_right_and_left_padding(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+ sequence = "Sequence"
+ padding_size = 10
+
+ # check correct behaviour if no pad_token_id exists and add it if needed
+ self._check_no_pad_token_padding(tokenizer, sequence)
+
+ padding_idx = tokenizer.pad_token_id
+
+ # RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
+ tokenizer.padding_side = "right"
+ encoded_sequence = tokenizer.encode(words, boxes=boxes)
+ sequence_length = len(encoded_sequence)
+ padded_sequence = tokenizer.encode(
+ words, boxes=boxes, max_length=sequence_length + padding_size, padding="max_length"
+ )
+ padded_sequence_length = len(padded_sequence)
+ assert sequence_length + padding_size == padded_sequence_length
+ assert encoded_sequence + [padding_idx] * padding_size == padded_sequence
+
+ # LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
+ tokenizer.padding_side = "left"
+ encoded_sequence = tokenizer.encode(words, boxes=boxes)
+ sequence_length = len(encoded_sequence)
+ padded_sequence = tokenizer.encode(
+ words, boxes=boxes, max_length=sequence_length + padding_size, padding="max_length"
+ )
+ padded_sequence_length = len(padded_sequence)
+ assert sequence_length + padding_size == padded_sequence_length
+ assert [padding_idx] * padding_size + encoded_sequence == padded_sequence
+
+ # RIGHT & LEFT PADDING - Check that nothing is done for 'longest' and 'no_padding'
+ encoded_sequence = tokenizer.encode(words, boxes=boxes)
+ sequence_length = len(encoded_sequence)
+
+ tokenizer.padding_side = "right"
+ padded_sequence_right = tokenizer.encode(words, boxes=boxes, padding=True)
+ padded_sequence_right_length = len(padded_sequence_right)
+ assert sequence_length == padded_sequence_right_length
+ assert encoded_sequence == padded_sequence_right
+
+ tokenizer.padding_side = "left"
+ padded_sequence_left = tokenizer.encode(words, boxes=boxes, padding="longest")
+ padded_sequence_left_length = len(padded_sequence_left)
+ assert sequence_length == padded_sequence_left_length
+ assert encoded_sequence == padded_sequence_left
+
+ tokenizer.padding_side = "right"
+ padded_sequence_right = tokenizer.encode(words, boxes=boxes)
+ padded_sequence_right_length = len(padded_sequence_right)
+ assert sequence_length == padded_sequence_right_length
+ assert encoded_sequence == padded_sequence_right
+
+ tokenizer.padding_side = "left"
+ padded_sequence_left = tokenizer.encode(words, boxes=boxes, padding=False)
+ padded_sequence_left_length = len(padded_sequence_left)
+ assert sequence_length == padded_sequence_left_length
+ assert encoded_sequence == padded_sequence_left
+
+ def test_token_type_ids(self):
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+
+ # test 1: single sequence
+ words, boxes = self.get_words_and_boxes()
+
+ output = tokenizer(words, boxes=boxes, return_token_type_ids=True)
+
+ # Assert that the token type IDs have the same length as the input IDs
+ self.assertEqual(len(output["token_type_ids"]), len(output["input_ids"]))
+
+ # Assert that the token type IDs have the same length as the attention mask
+ self.assertEqual(len(output["token_type_ids"]), len(output["attention_mask"]))
+
+ self.assertIn(0, output["token_type_ids"])
+ self.assertNotIn(1, output["token_type_ids"])
+
+ # test 2: two sequences (question + words)
+ question, words, boxes = self.get_question_words_and_boxes()
+
+ output = tokenizer(question, words, boxes, return_token_type_ids=True)
+
+ # Assert that the token type IDs have the same length as the input IDs
+ self.assertEqual(len(output["token_type_ids"]), len(output["input_ids"]))
+
+ # Assert that the token type IDs have the same length as the attention mask
+ self.assertEqual(len(output["token_type_ids"]), len(output["attention_mask"]))
+
+ self.assertIn(0, output["token_type_ids"])
+
+ def test_offsets_mapping(self):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ text = ["a", "wonderful", "test"]
+ boxes = [[1, 8, 12, 20] for _ in range(len(text))]
+
+ # No pair
+ tokens_with_offsets = tokenizer_r.encode_plus(
+ text,
+ boxes=boxes,
+ return_special_tokens_mask=True,
+ return_offsets_mapping=True,
+ add_special_tokens=True,
+ )
+ added_tokens = tokenizer_r.num_special_tokens_to_add(False)
+ offsets = tokens_with_offsets["offset_mapping"]
+
+ # Assert there is the same number of tokens and offsets
+ self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"]))
+
+ # Assert that exactly `added_tokens` special tokens are flagged in the special tokens mask
+ self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)
+
+ # Pairs
+ text = "what's his name"
+ pair = ["a", "wonderful", "test"]
+ boxes = [[1, 8, 12, 20] for _ in range(len(pair))]
+ tokens_with_offsets = tokenizer_r.encode_plus(
+ text,
+ pair,
+ boxes=boxes,
+ return_special_tokens_mask=True,
+ return_offsets_mapping=True,
+ add_special_tokens=True,
+ )
+ added_tokens = tokenizer_r.num_special_tokens_to_add(True)
+ offsets = tokens_with_offsets["offset_mapping"]
+
+ # Assert there is the same number of tokens and offsets
+ self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"]))
+
+ # Assert that exactly `added_tokens` special tokens are flagged in the special tokens mask
+ self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)
+
+ @require_torch
+ @slow
+ def test_torch_encode_plus_sent_to_model(self):
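+ # Integration check: the output of (batch_)encode_plus, completed with dummy pixel_values,
+ # should be accepted by the corresponding PyTorch model without errors.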
+ import torch
+
+ from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
+
+ MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
+
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+
+ if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
+ return
+
+ config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
+ config = config_class()
+
+ if config.is_encoder_decoder or config.pad_token_id is None:
+ return
+
+ model = model_class(config)
+
+ # Make sure the model contains at least the full vocabulary size in its embedding matrix
+ is_using_common_embeddings = hasattr(model.get_input_embeddings(), "weight")
+ assert (
+ (model.get_input_embeddings().weight.shape[0] >= len(tokenizer))
+ if is_using_common_embeddings
+ else True
+ )
+
+ # Build sequence
+ words, boxes = self.get_words_and_boxes()
+ encoded_sequence = tokenizer.encode_plus(words, boxes=boxes, return_tensors="pt")
+ batch_encoded_sequence = tokenizer.batch_encode_plus(
+ [words, words], boxes=[boxes, boxes], return_tensors="pt"
+ )
+
+ # We add dummy pixel_values keys (as LayoutLMv3 actually also requires a feature extractor
+ # to prepare the image input)
+ encoded_sequence["pixel_values"] = torch.randn(1, 3, 224, 224)
+ batch_encoded_sequence["pixel_values"] = torch.randn(2, 3, 224, 224)
+
+ # This should not fail
+ with torch.no_grad(): # saves some time
+ model(**encoded_sequence)
+ model(**batch_encoded_sequence)
+
+ def test_rust_and_python_full_tokenizers(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ if not self.test_slow_tokenizer:
+ # as we don't have a slow version, we can't compare the outputs between slow and fast versions
+ return
+
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ words, boxes = self.get_words_and_boxes()
+
+ ids = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ ids = tokenizer.encode(words, boxes=boxes, add_special_tokens=True)
+ rust_ids = rust_tokenizer.encode(words, boxes=boxes, add_special_tokens=True)
+ self.assertListEqual(ids, rust_ids)
+
+ def test_tokenization_python_rust_equals(self):
+ if not self.test_slow_tokenizer:
+ # as we don't have a slow version, we can't compare the outputs between slow and fast versions
+ return
+
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ words, boxes = self.get_words_and_boxes()
+
+ # Ensure basic input match
+ input_p = tokenizer_p.encode_plus(words, boxes=boxes)
+ input_r = tokenizer_r.encode_plus(words, boxes=boxes)
+
+ for key in filter(
+ lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys()
+ ):
+ self.assertSequenceEqual(input_p[key], input_r[key])
+
+ input_pairs_p = tokenizer_p.encode_plus(words, boxes=boxes)
+ input_pairs_r = tokenizer_r.encode_plus(words, boxes=boxes)
+
+ for key in filter(
+ lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys()
+ ):
+ self.assertSequenceEqual(input_pairs_p[key], input_pairs_r[key])
+
+ words = ["hello" for _ in range(1000)]
+ boxes = [[1000, 1000, 1000, 1000] for _ in range(1000)]
+
+ # Ensure truncation match
+ input_p = tokenizer_p.encode_plus(words, boxes=boxes, max_length=512, truncation=True)
+ input_r = tokenizer_r.encode_plus(words, boxes=boxes, max_length=512, truncation=True)
+
+ for key in filter(
+ lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys()
+ ):
+ self.assertSequenceEqual(input_p[key], input_r[key])
+
+ # Ensure truncation with stride match
+ input_p = tokenizer_p.encode_plus(
+ words, boxes=boxes, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
+ )
+ input_r = tokenizer_r.encode_plus(
+ words, boxes=boxes, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
+ )
+
+ for key in filter(
+ lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys()
+ ):
+ self.assertSequenceEqual(input_p[key], input_r[key][0])
+
+ def test_embeded_special_tokens(self):
+ if not self.test_slow_tokenizer:
+ # as we don't have a slow version, we can't compare the outputs between slow and fast versions
+ return
+
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ words, boxes = self.get_words_and_boxes()
+ tokens_r = tokenizer_r.encode_plus(
+ words,
+ boxes=boxes,
+ add_special_tokens=True,
+ )
+ tokens_p = tokenizer_p.encode_plus(
+ words,
+ boxes=boxes,
+ add_special_tokens=True,
+ )
+
+ for key in tokens_p.keys():
+ self.assertEqual(tokens_r[key], tokens_p[key])
+
+ if "token_type_ids" in tokens_r:
+ self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"]))
+
+ tokens_r = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"])
+ tokens_p = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"])
+ self.assertSequenceEqual(tokens_r, tokens_p)
+
+ def test_compare_add_special_tokens(self):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ simple_num_special_tokens_to_add = tokenizer_r.num_special_tokens_to_add(pair=False)
+
+ words, boxes = self.get_words_and_boxes()
+ # tokenize()
+ no_special_tokens = tokenizer_r.tokenize(" ".join(words), add_special_tokens=False)
+ with_special_tokens = tokenizer_r.tokenize(" ".join(words), add_special_tokens=True)
+ self.assertEqual(len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add)
+
+ # encode()
+ no_special_tokens = tokenizer_r.encode(words, boxes=boxes, add_special_tokens=False)
+ with_special_tokens = tokenizer_r.encode(words, boxes=boxes, add_special_tokens=True)
+ self.assertEqual(len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add)
+
+ # encode_plus()
+ no_special_tokens = tokenizer_r.encode_plus(words, boxes=boxes, add_special_tokens=False)
+ with_special_tokens = tokenizer_r.encode_plus(words, boxes=boxes, add_special_tokens=True)
+ for key in no_special_tokens.keys():
+ self.assertEqual(
+ len(no_special_tokens[key]),
+ len(with_special_tokens[key]) - simple_num_special_tokens_to_add,
+ )
+
+ # # batch_encode_plus
+ words, boxes = self.get_words_and_boxes_batch()
+
+ no_special_tokens = tokenizer_r.batch_encode_plus(words, boxes=boxes, add_special_tokens=False)
+ with_special_tokens = tokenizer_r.batch_encode_plus(words, boxes=boxes, add_special_tokens=True)
+ for key in no_special_tokens.keys():
+ for i_no, i_with in zip(no_special_tokens[key], with_special_tokens[key]):
+ self.assertEqual(len(i_no), len(i_with) - simple_num_special_tokens_to_add)
+
+ @slow
+ def test_layoutlmv3_truncation_integration_test(self):
+ words, boxes = self.get_words_and_boxes()
+
+ tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base", model_max_length=512)
+
+ for i in range(12, 512):
+ new_encoded_inputs = tokenizer.encode(words, boxes=boxes, max_length=i, truncation=True)
+
+ # Ensure that the number of input IDs does not exceed the specified max length.
+ self.assertLessEqual(len(new_encoded_inputs), i)
+
+ tokenizer.model_max_length = 20
+ new_encoded_inputs = tokenizer.encode(words, boxes=boxes, truncation=True)
+ dropped_encoded_inputs = tokenizer.encode(words, boxes=boxes, truncation=True)
+
+ # Ensure that the input IDs are still truncated when no max_length is specified
+ self.assertListEqual(new_encoded_inputs, dropped_encoded_inputs)
+ self.assertLessEqual(len(new_encoded_inputs), 20)
+
+ @is_pt_tf_cross_test
+ def test_batch_encode_plus_tensors(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes_batch()
+
+ # A tensor cannot be built from sequences that are not the same size
+ self.assertRaises(ValueError, tokenizer.batch_encode_plus, words, boxes=boxes, return_tensors="pt")
+ self.assertRaises(ValueError, tokenizer.batch_encode_plus, words, boxes=boxes, return_tensors="tf")
+
+ if tokenizer.pad_token_id is None:
+ self.assertRaises(
+ ValueError,
+ tokenizer.batch_encode_plus,
+ words,
+ boxes=boxes,
+ padding=True,
+ return_tensors="pt",
+ )
+ self.assertRaises(
+ ValueError,
+ tokenizer.batch_encode_plus,
+ words,
+ boxes=boxes,
+ padding="longest",
+ return_tensors="tf",
+ )
+ else:
+ pytorch_tensor = tokenizer.batch_encode_plus(words, boxes=boxes, padding=True, return_tensors="pt")
+ tensorflow_tensor = tokenizer.batch_encode_plus(
+ words, boxes=boxes, padding="longest", return_tensors="tf"
+ )
+ encoded_sequences = tokenizer.batch_encode_plus(words, boxes=boxes, padding=True)
+
+ for key in encoded_sequences.keys():
+ pytorch_value = pytorch_tensor[key].tolist()
+ tensorflow_value = tensorflow_tensor[key].numpy().tolist()
+ encoded_value = encoded_sequences[key]
+
+ self.assertEqual(pytorch_value, tensorflow_value, encoded_value)
+
+ def test_sequence_ids(self):
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ if not tokenizer.is_fast:
+ continue
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ seq_0 = "Test this method."
+ seq_1 = ["With", "these", "inputs."]
+ boxes = [[1000, 1000, 1000, 1000] for _ in range(len(seq_1))]
+
+ # We want sequence 0 and sequence 1 to be tagged with
+ # sequence ids 0 and 1 respectively
+ # (regardless of whether the model uses token type ids).
+ # We rely on this assumption in the QA pipeline, among other places.
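+ # For illustration (assuming the usual fast-tokenizer convention that special tokens map
+ # to None), a pair input yields something like
+ # [None, 0, 0, ..., None, None, 1, 1, ..., None] from output.sequence_ids().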
+ output = tokenizer(seq_0.split(), boxes=boxes)
+ self.assertIn(0, output.sequence_ids())
+
+ output = tokenizer(seq_0, seq_1, boxes=boxes)
+ self.assertIn(0, output.sequence_ids())
+ self.assertIn(1, output.sequence_ids())
+
+ if tokenizer.num_special_tokens_to_add(pair=True):
+ self.assertIn(None, output.sequence_ids())
+
+ def test_special_tokens_initialization(self):
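+ # A token passed via additional_special_tokens at init time should be encoded as a single
+ # special id by the fast tokenizer and, when a slow tokenizer exists, by the slow and
+ # converted-from-slow variants as well.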
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+
+ added_tokens = [AddedToken("<special>", lstrip=True)]
+
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(
+ pretrained_name, additional_special_tokens=added_tokens, **kwargs
+ )
+ words = "Hey this is a token".split()
+ boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))]
+ r_output = tokenizer_r.encode(words, boxes=boxes)
+
+ special_token_id = tokenizer_r.encode(
+ ["<special>"], boxes=[1000, 1000, 1000, 1000], add_special_tokens=False
+ )[0]
+
+ self.assertTrue(special_token_id in r_output)
+
+ if self.test_slow_tokenizer:
+ tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
+ pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True
+ )
+ tokenizer_p = self.tokenizer_class.from_pretrained(
+ pretrained_name, additional_special_tokens=added_tokens, **kwargs
+ )
+
+ words = "Hey this is a token".split()
+ boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))]
+
+ p_output = tokenizer_p.encode(words, boxes=boxes)
+ cr_output = tokenizer_cr.encode(words, boxes=boxes)
+
+ self.assertEqual(p_output, r_output)
+ self.assertEqual(cr_output, r_output)
+ self.assertTrue(special_token_id in p_output)
+ self.assertTrue(special_token_id in cr_output)
+
+ def test_training_new_tokenizer(self):
+ # This feature only exists for fast tokenizers
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_rust_tokenizer()
+ new_tokenizer = tokenizer.train_new_from_iterator(SMALL_TRAINING_CORPUS, 100)
+
+ # Test we can use the new tokenizer with something not seen during training
+ text = [["this", "is", "the"], ["how", "are", "you"]]
+ boxes = [[[1, 2, 3, 4], [5, 6, 7, 8], [1, 3, 4, 8]], [[5, 6, 7, 8], [4, 5, 6, 7], [3, 9, 2, 7]]]
+ inputs = new_tokenizer(text, boxes=boxes)
+ self.assertEqual(len(inputs["input_ids"]), 2)
+ decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
+ expected_result = " this is the"
+
+ if tokenizer.backend_tokenizer.normalizer is not None:
+ expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result)
+ self.assertEqual(expected_result, decoded_input)
+
+ # We check that the parameters of the tokenizer remained the same
+ # Check we have the same number of added_tokens for both pair and non-pair inputs.
+ self.assertEqual(tokenizer.num_special_tokens_to_add(False), new_tokenizer.num_special_tokens_to_add(False))
+ self.assertEqual(tokenizer.num_special_tokens_to_add(True), new_tokenizer.num_special_tokens_to_add(True))
+
+ # Check we have the correct max_length for both pair and non-pair inputs.
+ self.assertEqual(tokenizer.max_len_single_sentence, new_tokenizer.max_len_single_sentence)
+ self.assertEqual(tokenizer.max_len_sentences_pair, new_tokenizer.max_len_sentences_pair)
+
+ # Assert the set of special tokens match as we didn't ask to change them
+ self.assertSequenceEqual(
+ tokenizer.all_special_tokens_extended,
+ new_tokenizer.all_special_tokens_extended,
+ )
+
+ self.assertDictEqual(tokenizer.special_tokens_map, new_tokenizer.special_tokens_map)
+
+ def test_training_new_tokenizer_with_special_tokens_change(self):
+ # This feature only exists for fast tokenizers
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_rust_tokenizer()
+ # Test with a special tokens map
+ class_signature = inspect.signature(tokenizer.__class__)
+ if "cls_token" in class_signature.parameters:
+ new_tokenizer = tokenizer.train_new_from_iterator(
+ SMALL_TRAINING_CORPUS, 100, special_tokens_map={tokenizer.cls_token: "<cls>"}
+ )
+ cls_id = new_tokenizer.get_vocab()["<cls>"]
+ self.assertEqual(new_tokenizer.cls_token, "<cls>")
+ self.assertEqual(new_tokenizer.cls_token_id, cls_id)
+
+ # Create a new mapping from the special tokens defined in the original tokenizer
+ special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy()
+ special_tokens_list.remove("additional_special_tokens")
+ special_tokens_map = {}
+ for token in special_tokens_list:
+ # Get the private one to avoid unnecessary warnings.
+ if getattr(tokenizer, f"_{token}") is not None:
+ special_token = getattr(tokenizer, token)
+ special_tokens_map[special_token] = f"{special_token}a"
+
+ # Train new tokenizer
+ new_tokenizer = tokenizer.train_new_from_iterator(
+ SMALL_TRAINING_CORPUS, 100, special_tokens_map=special_tokens_map
+ )
+
+ # Check the changes
+ for token in special_tokens_list:
+ # Get the private one to avoid unnecessary warnings.
+ if getattr(tokenizer, f"_{token}") is None:
+ continue
+ special_token = getattr(tokenizer, token)
+ if special_token in special_tokens_map:
+ new_special_token = getattr(new_tokenizer, token)
+ self.assertEqual(special_tokens_map[special_token], new_special_token)
+
+ new_id = new_tokenizer.get_vocab()[new_special_token]
+ self.assertEqual(getattr(new_tokenizer, f"{token}_id"), new_id)
+
+ # Check if the AddedToken / string format has been kept
+ for special_token in tokenizer.all_special_tokens_extended:
+ if isinstance(special_token, AddedToken) and special_token.content not in special_tokens_map:
+ # The special token must appear identically in the list of the new tokenizer.
+ self.assertTrue(
+ special_token in new_tokenizer.all_special_tokens_extended,
+ f"'{special_token}' should be in {new_tokenizer.all_special_tokens_extended}",
+ )
+ elif isinstance(special_token, AddedToken):
+ # The special token must appear in the list of the new tokenizer as an object of type AddedToken with
+ # the same parameters as the old AddedToken except the content that the user has requested to change.
+ special_token_str = special_token.content
+ new_special_token_str = special_tokens_map[special_token_str]
+
+ find = False
+ for candidate in new_tokenizer.all_special_tokens_extended:
+ if (
+ isinstance(candidate, AddedToken)
+ and candidate.content == new_special_token_str
+ and candidate.lstrip == special_token.lstrip
+ and candidate.rstrip == special_token.rstrip
+ and candidate.normalized == special_token.normalized
+ and candidate.single_word == special_token.single_word
+ ):
+ find = True
+ break
+ self.assertTrue(
+ find,
+ f"'{new_special_token_str}' doesn't appear in the list "
+ f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
+ f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}",
+ )
+ elif special_token not in special_tokens_map:
+ # The special token must appear identically in the list of the new tokenizer.
+ self.assertTrue(
+ special_token in new_tokenizer.all_special_tokens_extended,
+ f"'{special_token}' should be in {new_tokenizer.all_special_tokens_extended}",
+ )
+
+ else:
+ # The special token must appear in the list of the new tokenizer as an object of type string.
+ self.assertTrue(special_tokens_map[special_token] in new_tokenizer.all_special_tokens_extended)
+
+ # Test we can use the new tokenizer with something not seen during training
+ words = [["this", "is"], ["hello", "š¤"]]
+ boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[1, 2, 3, 4], [5, 6, 7, 8]]]
+ inputs = new_tokenizer(words, boxes=boxes)
+ self.assertEqual(len(inputs["input_ids"]), 2)
+ decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
+ expected_result = " this is"
+
+ if tokenizer.backend_tokenizer.normalizer is not None:
+ expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result)
+ self.assertEqual(expected_result, decoded_input)
+
+ def test_prepare_for_model(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ # only test prepare_for_model for the slow tokenizer
+ if tokenizer.__class__.__name__ == "LayoutLMv3TokenizerFast":
+ continue
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+ prepared_input_dict = tokenizer.prepare_for_model(words, boxes=boxes, add_special_tokens=True)
+
+ input_dict = tokenizer.encode_plus(words, boxes=boxes, add_special_tokens=True)
+
+ self.assertEqual(input_dict, prepared_input_dict)
+
+ def test_padding_different_model_input_name(self):
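+ # Rename the main model input (e.g. input_ids) to "inputs" and check that pad() still pads
+ # the renamed input correctly.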
+ if not self.test_slow_tokenizer:
+ # as we don't have a slow version, we can't compare the outputs between slow and fast versions
+ return
+
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id)
+ pad_token_id = tokenizer_p.pad_token_id
+
+ words, boxes = self.get_words_and_boxes_batch()
+
+ input_r = tokenizer_r.batch_encode_plus(words, boxes=boxes)
+ input_p = tokenizer_r.batch_encode_plus(words, boxes=boxes)
+
+ # rename encoded batch to "inputs"
+ input_r["inputs"] = input_r[tokenizer_r.model_input_names[0]]
+ del input_r[tokenizer_r.model_input_names[0]]
+
+ input_p["inputs"] = input_p[tokenizer_p.model_input_names[0]]
+ del input_p[tokenizer_p.model_input_names[0]]
+
+ # Renaming `input_ids` to `inputs`
+ tokenizer_r.model_input_names = ["inputs"] + tokenizer_r.model_input_names[1:]
+ tokenizer_p.model_input_names = ["inputs"] + tokenizer_p.model_input_names[1:]
+
+ input_r = tokenizer_r.pad(input_r, padding="longest")
+ input_p = tokenizer_r.pad(input_p, padding="longest")
+
+ max_length = len(input_p["inputs"][0])
+ self.assert_batch_padded_input_match(
+ input_r, input_p, max_length, pad_token_id, model_main_input_name="inputs"
+ )
+
+ def test_batch_encode_dynamic_overflowing(self):
+ """
+ When calling batch_encode with multiple sequences, it can return a different number of
+ overflowing encodings for each sequence:
+ [
+ Sequence 1: [Encoding 1, Encoding 2],
+ Sequence 2: [Encoding 1],
+ Sequence 3: [Encoding 1, Encoding 2, ... Encoding N]
+ ]
+ This needs to be padded so that it can be represented as a tensor
+ """
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ tokenizer = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name}, {tokenizer.__class__.__name__})"):
+
+ if is_torch_available():
+ returned_tensor = "pt"
+ elif is_tf_available():
+ returned_tensor = "tf"
+ else:
+ returned_tensor = "jax"
+
+ # Single example
+ words = ["HuggingFace", "is", "solving", "NLP", "one", "commit", "at", "a", "time"]
+ boxes = [[i, i, i, i] for i in range(len(words))]
+ tokens = tokenizer.encode_plus(
+ words,
+ boxes=boxes,
+ max_length=6,
+ padding=True,
+ truncation=True,
+ return_tensors=returned_tensor,
+ return_overflowing_tokens=True,
+ )
+
+ for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
+ if key != "bbox":
+ self.assertEqual(len(tokens[key].shape), 2)
+ else:
+ self.assertEqual(len(tokens[key].shape), 3)
+
+ # Batch of examples
+ # For these 2 examples, 3 training examples will be created
+ words_batched = [
+ ["HuggingFace", "is", "solving", "NLP", "one", "commit", "at", "a", "time"],
+ ["Very", "tiny", "input"],
+ ]
+ boxes_batched = [[[i, i, i, i] for i in range(len(words_item))] for words_item in words_batched]
+ tokens = tokenizer.batch_encode_plus(
+ words_batched,
+ boxes=boxes_batched,
+ max_length=6,
+ padding=True,
+ truncation="only_first",
+ return_tensors=returned_tensor,
+ return_overflowing_tokens=True,
+ )
+
+ for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
+ if key != "bbox":
+ self.assertEqual(len(tokens[key].shape), 2)
+ self.assertEqual(tokens[key].shape[-1], 6)
+ else:
+ self.assertEqual(len(tokens[key].shape), 3)
+ self.assertEqual(tokens[key].shape[-1], 4)
+
+ @unittest.skip("TO DO: overwrite this very extensive test.")
+ def test_alignement_methods(self):
+ pass
+
+ def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5):
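+ # Build a short "clean" word sequence from the tokenizer's vocabulary: keep only purely
+ # alphabetic tokens that round-trip through encode() as a single id, then derive dummy
+ # bounding boxes and the matching ids for use in the maximum-length tests.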
+ toks = [(i, tokenizer.decode([i], clean_up_tokenization_spaces=False)) for i in range(len(tokenizer))]
+ toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
+ toks = list(
+ filter(
+ lambda t: [t[0]]
+ == tokenizer.encode(t[1].split(" "), boxes=len(t[1]) * [[1, 1, 1, 1]], add_special_tokens=False),
+ toks,
+ )
+ )
+ if max_length is not None and len(toks) > max_length:
+ toks = toks[:max_length]
+ if min_length is not None and len(toks) < min_length and len(toks) > 0:
+ while len(toks) < min_length:
+ toks = toks + toks
+ # toks_str = [t[1] for t in toks]
+ toks_ids = [t[0] for t in toks]
+
+ # Ensure consistency
+ output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
+ if " " not in output_txt and len(toks_ids) > 1:
+ output_txt = (
+ tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
+ + " "
+ + tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
+ )
+ if with_prefix_space:
+ output_txt = " " + output_txt
+ words = output_txt.split(" ")
+ boxes = [[i, i, i, i] for i in range(len(words))]
+ output_ids = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+
+ return words, boxes, output_ids
+
+ def test_added_token_with_space_before(self):
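+ # Newly added tokens (registered with lstrip/rstrip) should be tokenized identically by the
+ # slow and fast tokenizers; inputs with and without a leading space are exercised below.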
+
+ tokenizer_s = self.get_tokenizer()
+ tokenizer_f = self.get_rust_tokenizer()
+
+ tokens_to_add = ["AAA", "bbb"]
+
+ words_with_space = [f" {token}" for token in tokens_to_add + tokenizer_s.unique_no_split_tokens]
+ words_without_space = tokens_to_add + tokenizer_s.unique_no_split_tokens
+ boxes = [[i, i, i, i] for i in range(len(words_with_space))]
+
+ tokens_to_add_formated = [
+ AddedToken(token, rstrip=True, lstrip=True, single_word=False) for token in tokens_to_add
+ ]
+ tokenizer_s.add_tokens(tokens_to_add_formated)
+ tokenizer_f.add_tokens(tokens_to_add_formated)
+
+ ids_s = tokenizer_s(words_with_space, boxes=boxes).input_ids
+ ids_f = tokenizer_f(words_with_space, boxes=boxes).input_ids
+
+ tokens_s = tokenizer_s.convert_ids_to_tokens(ids_s)
+ tokens_f = tokenizer_f.convert_ids_to_tokens(ids_f)
+
+ ids_s = tokenizer_s(words_without_space, boxes=boxes).input_ids
+ ids_f = tokenizer_f(words_without_space, boxes=boxes).input_ids
+
+ tokens_s = tokenizer_s.convert_ids_to_tokens(ids_s)
+ tokens_f = tokenizer_f.convert_ids_to_tokens(ids_f)
+
+ self.assertEqual(tokens_s, tokens_f)
+
+ def test_maximum_encoding_length_pair_input(self):
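+ # Build a question/words pair that exceeds model_max_length, then check that every
+ # combination of padding and truncation strategy yields input_ids and bbox of the expected
+ # length, and that overflowing tokens and boxes come back in the expected order for both
+ # the slow and the fast tokenizer.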
+ tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ # Build a sequence from our model's vocabulary
+ stride = 2
+ seq_0, boxes_0, ids = self.get_clean_sequence(tokenizer, max_length=20)
+ question_0 = " ".join(map(str, seq_0))
+ if len(ids) <= 2 + stride:
+ seq_0 = (seq_0 + " ") * (2 + stride)
+ ids = None
+
+ seq0_tokens = tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)
+ seq0_input_ids = seq0_tokens["input_ids"]
+
+ self.assertGreater(len(seq0_input_ids), 2 + stride)
+ question_1 = "This is another sentence to be encoded."
+ seq_1 = ["what", "a", "weird", "test", "weirdly", "weird"]
+ boxes_1 = [[i, i, i, i] for i in range(1, len(seq_1) + 1)]
+ seq1_tokens = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
+ if abs(len(seq0_input_ids) - len(seq1_tokens["input_ids"])) <= 2:
+ seq1_tokens_input_ids = seq1_tokens["input_ids"] + seq1_tokens["input_ids"]
+ seq_1 = tokenizer.decode(seq1_tokens_input_ids, clean_up_tokenization_spaces=False)
+ seq_1 = seq_1.split(" ")
+ boxes_1 = [[i, i, i, i] for i in range(1, len(seq_1) + 1)]
+ seq1_tokens = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
+ seq1_input_ids = seq1_tokens["input_ids"]
+
+ self.assertGreater(len(seq1_input_ids), 2 + stride)
+
+ smallest = seq1_input_ids if len(seq0_input_ids) > len(seq1_input_ids) else seq0_input_ids
+
+ # We are not using the special tokens - a bit too hard to test all the tokenizers with this
+ # TODO try this again later
+ sequence = tokenizer(
+ question_0, seq_1, boxes=boxes_1, add_special_tokens=False
+ ) # , add_prefix_space=False)
+
+ # Test with max model input length
+ model_max_length = tokenizer.model_max_length
+ self.assertEqual(model_max_length, 100)
+ seq_2 = seq_0 * model_max_length
+ question_2 = " ".join(map(str, seq_2))
+ boxes_2 = boxes_0 * model_max_length
+ self.assertGreater(len(seq_2), model_max_length)
+
+ sequence1 = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
+ total_length1 = len(sequence1["input_ids"])
+ sequence2 = tokenizer(question_2, seq_1, boxes=boxes_1, add_special_tokens=False)
+ total_length2 = len(sequence2["input_ids"])
+ self.assertLess(total_length1, model_max_length, "Issue with the testing sequence, please update it.")
+ self.assertGreater(
+ total_length2, model_max_length, "Issue with the testing sequence, please update it."
+ )
+
+ # Simple
+ padding_strategies = (
+ [False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False]
+ )
+ for padding_state in padding_strategies:
+ with self.subTest(f"{tokenizer.__class__.__name__} Padding: {padding_state}"):
+ for truncation_state in [True, "longest_first", "only_first"]:
+ with self.subTest(f"{tokenizer.__class__.__name__} Truncation: {truncation_state}"):
+ output = tokenizer(
+ question_2,
+ seq_1,
+ boxes=boxes_1,
+ padding=padding_state,
+ truncation=truncation_state,
+ )
+ self.assertEqual(len(output["input_ids"]), model_max_length)
+ self.assertEqual(len(output["bbox"]), model_max_length)
+
+ output = tokenizer(
+ [question_2],
+ [seq_1],
+ boxes=[boxes_1],
+ padding=padding_state,
+ truncation=truncation_state,
+ )
+ self.assertEqual(len(output["input_ids"][0]), model_max_length)
+ self.assertEqual(len(output["bbox"][0]), model_max_length)
+
+ # Simple
+ output = tokenizer(
+ question_1, seq_2, boxes=boxes_2, padding=padding_state, truncation="only_second"
+ )
+ self.assertEqual(len(output["input_ids"]), model_max_length)
+ self.assertEqual(len(output["bbox"]), model_max_length)
+
+ output = tokenizer(
+ [question_1], [seq_2], boxes=[boxes_2], padding=padding_state, truncation="only_second"
+ )
+ self.assertEqual(len(output["input_ids"][0]), model_max_length)
+ self.assertEqual(len(output["bbox"][0]), model_max_length)
+
+ # Simple with no truncation
+ # Reset warnings
+ tokenizer.deprecation_warnings = {}
+ with self.assertLogs("transformers", level="WARNING") as cm:
+ output = tokenizer(
+ question_1, seq_2, boxes=boxes_2, padding=padding_state, truncation=False
+ )
+ self.assertNotEqual(len(output["input_ids"]), model_max_length)
+ self.assertNotEqual(len(output["bbox"]), model_max_length)
+ self.assertEqual(len(cm.records), 1)
+ self.assertTrue(
+ cm.records[0].message.startswith(
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
+ )
+ )
+
+ tokenizer.deprecation_warnings = {}
+ with self.assertLogs("transformers", level="WARNING") as cm:
+ output = tokenizer(
+ [question_1], [seq_2], boxes=[boxes_2], padding=padding_state, truncation=False
+ )
+ self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
+ self.assertNotEqual(len(output["bbox"][0]), model_max_length)
+ self.assertEqual(len(cm.records), 1)
+ self.assertTrue(
+ cm.records[0].message.startswith(
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
+ )
+ )
+ # Check the ordering of the input ids, overflowing tokens and bbox sequences under truncation
+ truncated_first_sequence = (
+ tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"][:-2]
+ + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"]
+ )
+ truncated_second_sequence = (
+ tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"]
+ + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"][:-2]
+ )
+ truncated_longest_sequence = (
+ truncated_first_sequence
+ if len(seq0_input_ids) > len(seq1_input_ids)
+ else truncated_second_sequence
+ )
+
+ overflow_first_sequence = (
+ tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"][-(2 + stride) :]
+ + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"]
+ )
+ overflow_second_sequence = (
+ tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"]
+ + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"][-(2 + stride) :]
+ )
+ overflow_longest_sequence = (
+ overflow_first_sequence if len(seq0_input_ids) > len(seq1_input_ids) else overflow_second_sequence
+ )
+
+ bbox_first = [[0, 0, 0, 0]] * (len(seq0_input_ids) - 2)
+ bbox_first_sequence = bbox_first + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["bbox"]
+ overflowing_token_bbox_first_sequence_slow = [[0, 0, 0, 0]] * (2 + stride)
+ overflowing_token_bbox_first_sequence_fast = [[0, 0, 0, 0]] * (2 + stride) + tokenizer(
+ seq_1, boxes=boxes_1, add_special_tokens=False
+ )["bbox"]
+
+ bbox_second = [[0, 0, 0, 0]] * len(seq0_input_ids)
+ bbox_second_sequence = (
+ bbox_second + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["bbox"][:-2]
+ )
+ overflowing_token_bbox_second_sequence_slow = tokenizer(
+ seq_1, boxes=boxes_1, add_special_tokens=False
+ )["bbox"][-(2 + stride) :]
+ overflowing_token_bbox_second_sequence_fast = [[0, 0, 0, 0]] * len(seq0_input_ids) + tokenizer(
+ seq_1, boxes=boxes_1, add_special_tokens=False
+ )["bbox"][-(2 + stride) :]
+
+ bbox_longest_sequence = (
+ bbox_first_sequence if len(seq0_tokens) > len(seq1_tokens) else bbox_second_sequence
+ )
+ overflowing_token_bbox_longest_sequence_fast = (
+ overflowing_token_bbox_first_sequence_fast
+ if len(seq0_tokens) > len(seq1_tokens)
+ else overflowing_token_bbox_second_sequence_fast
+ )
+
+ # Overflowing tokens are handled quite differently in slow and fast tokenizers
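+ # (the fast tokenizer returns one encoding per window, so input_ids[0] is the truncated
+ # sequence and input_ids[1] the overflow, while the slow tokenizer returns a single encoding
+ # with separate "overflowing_tokens" / "overflowing_token_boxes" entries)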
+ if isinstance(tokenizer, LayoutLMv3TokenizerFast):
+ information = tokenizer(
+ question_0,
+ seq_1,
+ boxes=boxes_1,
+ max_length=len(sequence["input_ids"]) - 2,
+ add_special_tokens=False,
+ stride=stride,
+ truncation="longest_first",
+ return_overflowing_tokens=True,
+ # add_prefix_space=False,
+ )
+ truncated_sequence = information["input_ids"][0]
+ overflowing_tokens = information["input_ids"][1]
+ bbox = information["bbox"][0]
+ overflowing_bbox = information["bbox"][1]
+ self.assertEqual(len(information["input_ids"]), 2)
+
+ self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
+ self.assertEqual(truncated_sequence, truncated_longest_sequence)
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
+ self.assertEqual(overflowing_tokens, overflow_longest_sequence)
+ self.assertEqual(bbox, bbox_longest_sequence)
+
+ self.assertEqual(len(overflowing_bbox), 2 + stride + len(smallest))
+ self.assertEqual(overflowing_bbox, overflowing_token_bbox_longest_sequence_fast)
+ else:
+ # No overflowing tokens when using 'longest_first' in python tokenizers
+ with self.assertRaises(ValueError) as context:
+ information = tokenizer(
+ question_0,
+ seq_1,
+ boxes=boxes_1,
+ max_length=len(sequence["input_ids"]) - 2,
+ add_special_tokens=False,
+ stride=stride,
+ truncation="longest_first",
+ return_overflowing_tokens=True,
+ # add_prefix_space=False,
+ )
+
+ self.assertTrue(
+ context.exception.args[0].startswith(
+ "Not possible to return overflowing tokens for pair of sequences with the "
+ "`longest_first`. Please select another truncation strategy than `longest_first`, "
+ "for instance `only_second` or `only_first`."
+ )
+ )
+
+ # Overflowing tokens are handled quite differently in slow and fast tokenizers
+ if isinstance(tokenizer, LayoutLMv3TokenizerFast):
+ information = tokenizer(
+ question_0,
+ seq_1,
+ boxes=boxes_1,
+ max_length=len(sequence["input_ids"]) - 2,
+ add_special_tokens=False,
+ stride=stride,
+ truncation=True,
+ return_overflowing_tokens=True,
+ # add_prefix_space=False,
+ )
+ truncated_sequence = information["input_ids"][0]
+ overflowing_tokens = information["input_ids"][1]
+ bbox = information["bbox"][0]
+ overflowing_bbox = information["bbox"][1]
+ self.assertEqual(len(information["input_ids"]), 2)
+
+ self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
+ self.assertEqual(truncated_sequence, truncated_longest_sequence)
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
+ self.assertEqual(overflowing_tokens, overflow_longest_sequence)
+ self.assertEqual(bbox, bbox_longest_sequence)
+ self.assertEqual(overflowing_bbox, overflowing_token_bbox_longest_sequence_fast)
+ else:
+ # No overflowing tokens when using 'longest_first' in python tokenizers
+ with self.assertRaises(ValueError) as context:
+ information = tokenizer(
+ question_0,
+ seq_1,
+ boxes=boxes_1,
+ max_length=len(sequence["input_ids"]) - 2,
+ add_special_tokens=False,
+ stride=stride,
+ truncation=True,
+ return_overflowing_tokens=True,
+ # add_prefix_space=False,
+ )
+
+ self.assertTrue(
+ context.exception.args[0].startswith(
+ "Not possible to return overflowing tokens for pair of sequences with the "
+ "`longest_first`. Please select another truncation strategy than `longest_first`, "
+ "for instance `only_second` or `only_first`."
+ )
+ )
+
+ information_first_truncated = tokenizer(
+ question_0,
+ seq_1,
+ boxes=boxes_1,
+ max_length=len(sequence["input_ids"]) - 2,
+ add_special_tokens=False,
+ stride=stride,
+ truncation="only_first",
+ return_overflowing_tokens=True,
+ # add_prefix_space=False,
+ )
+ # Overflowing tokens are handled quite differently in slow and fast tokenizers
+ if isinstance(tokenizer, LayoutLMv3TokenizerFast):
+ truncated_sequence = information_first_truncated["input_ids"][0]
+ overflowing_tokens = information_first_truncated["input_ids"][1]
+ bbox = information_first_truncated["bbox"][0]
+ overflowing_bbox = information_first_truncated["bbox"][0]
+ self.assertEqual(len(information_first_truncated["input_ids"]), 2)
+
+ self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
+ self.assertEqual(truncated_sequence, truncated_first_sequence)
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride + len(seq1_input_ids))
+ self.assertEqual(overflowing_tokens, overflow_first_sequence)
+ self.assertEqual(bbox, bbox_first_sequence)
+ self.assertEqual(overflowing_bbox, overflowing_token_bbox_first_sequence_fast)
+ else:
+ truncated_sequence = information_first_truncated["input_ids"]
+ overflowing_tokens = information_first_truncated["overflowing_tokens"]
+ overflowing_bbox = information_first_truncated["overflowing_token_boxes"]
+ bbox = information_first_truncated["bbox"]
+
+ self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
+ self.assertEqual(truncated_sequence, truncated_first_sequence)
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride)
+ self.assertEqual(overflowing_tokens, seq0_input_ids[-(2 + stride) :])
+ self.assertEqual(bbox, bbox_first_sequence)
+ self.assertEqual(overflowing_bbox, overflowing_token_bbox_first_sequence_slow)
+
+ information_second_truncated = tokenizer(
+ question_0,
+ seq_1,
+ boxes=boxes_1,
+ max_length=len(sequence["input_ids"]) - 2,
+ add_special_tokens=False,
+ stride=stride,
+ truncation="only_second",
+ return_overflowing_tokens=True,
+ # add_prefix_space=False,
+ )
+ # Overflowing tokens are handled quite differently in slow and fast tokenizers
+ if isinstance(tokenizer, LayoutLMv3TokenizerFast):
+ truncated_sequence = information_second_truncated["input_ids"][0]
+ overflowing_tokens = information_second_truncated["input_ids"][1]
+ bbox = information_second_truncated["bbox"][0]
+ overflowing_bbox = information_second_truncated["bbox"][1]
+
+ self.assertEqual(len(information_second_truncated["input_ids"]), 2)
+
+ self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
+ self.assertEqual(truncated_sequence, truncated_second_sequence)
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride + len(seq0_input_ids))
+ self.assertEqual(overflowing_tokens, overflow_second_sequence)
+ self.assertEqual(bbox, bbox_second_sequence)
+ self.assertEqual(overflowing_bbox, overflowing_token_bbox_second_sequence_fast)
+ else:
+ truncated_sequence = information_second_truncated["input_ids"]
+ overflowing_tokens = information_second_truncated["overflowing_tokens"]
+ bbox = information_second_truncated["bbox"]
+ overflowing_bbox = information_second_truncated["overflowing_token_boxes"]
+
+ self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
+ self.assertEqual(truncated_sequence, truncated_second_sequence)
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride)
+ self.assertEqual(overflowing_tokens, seq1_input_ids[-(2 + stride) :])
+ self.assertEqual(bbox, bbox_second_sequence)
+ self.assertEqual(overflowing_bbox, overflowing_token_bbox_second_sequence_slow)
+
+ def test_maximum_encoding_length_single_input(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ seq_0, boxes_0, ids = self.get_clean_sequence(tokenizer, max_length=20)
+
+ sequence = tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)
+ total_length = len(sequence["input_ids"])
+
+ self.assertGreater(
+ total_length, 4, "Issue with the testing sequence, please update it, it's too short"
+ )
+
+ # Test with max model input length
+ model_max_length = tokenizer.model_max_length
+ self.assertEqual(model_max_length, 100)
+ seq_1 = seq_0 * model_max_length
+ boxes_1 = boxes_0 * model_max_length
+ sequence1 = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
+ total_length1 = len(sequence1["input_ids"])
+ self.assertGreater(
+ total_length1,
+ model_max_length,
+ "Issue with the testing sequence, please update it, it's too short",
+ )
+
+ # Simple
+ padding_strategies = (
+ [False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False]
+ )
+ for padding_state in padding_strategies:
+ with self.subTest(f"Padding: {padding_state}"):
+ for truncation_state in [True, "longest_first", "only_first"]:
+ with self.subTest(f"Truncation: {truncation_state}"):
+ output = tokenizer(
+ seq_1,
+ boxes=boxes_1,
+ padding=padding_state,
+ truncation=truncation_state,
+ )
+
+ self.assertEqual(len(output["input_ids"]), model_max_length)
+ self.assertEqual(len(output["bbox"]), model_max_length)
+
+ output = tokenizer(
+ [seq_1],
+ boxes=[boxes_1],
+ padding=padding_state,
+ truncation=truncation_state,
+ )
+ self.assertEqual(len(output["input_ids"][0]), model_max_length)
+ self.assertEqual(len(output["bbox"][0]), model_max_length)
+
+ # Simple with no truncation
+ # Reset warnings
+ tokenizer.deprecation_warnings = {}
+ with self.assertLogs("transformers", level="WARNING") as cm:
+ output = tokenizer(seq_1, boxes=boxes_1, padding=padding_state, truncation=False)
+ self.assertNotEqual(len(output["input_ids"]), model_max_length)
+ self.assertNotEqual(len(output["bbox"]), model_max_length)
+ self.assertEqual(len(cm.records), 1)
+ self.assertTrue(
+ cm.records[0].message.startswith(
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
+ )
+ )
+
+ tokenizer.deprecation_warnings = {}
+ with self.assertLogs("transformers", level="WARNING") as cm:
+ output = tokenizer([seq_1], boxes=[boxes_1], padding=padding_state, truncation=False)
+ self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
+ self.assertNotEqual(len(output["bbox"][0]), model_max_length)
+ self.assertEqual(len(cm.records), 1)
+ self.assertTrue(
+ cm.records[0].message.startswith(
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
+ )
+ )
+ # Check the order of the input ids, overflowing tokens and bbox sequences with truncation
+ stride = 2
+ information = tokenizer(
+ seq_0,
+ boxes=boxes_0,
+ max_length=total_length - 2,
+ add_special_tokens=False,
+ stride=stride,
+ truncation=True,
+ return_overflowing_tokens=True,
+ # add_prefix_space=False,
+ )
+
+ # Overflowing tokens are handled quite differently in slow and fast tokenizers
+ if isinstance(tokenizer, LayoutLMv3TokenizerFast):
+ truncated_sequence = information["input_ids"][0]
+ overflowing_tokens = information["input_ids"][1]
+ # bbox = information["bbox"][0]
+ # overflowing_bbox = information["bbox"][1]
+ self.assertEqual(len(information["input_ids"]), 2)
+
+ self.assertEqual(len(truncated_sequence), total_length - 2)
+ self.assertEqual(truncated_sequence, sequence["input_ids"][:-2])
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride)
+ self.assertEqual(overflowing_tokens, sequence["input_ids"][-(2 + stride) :])
+
+ # self.assertEqual(bbox, sequence["bbox"][:-2])
+ # self.assertEqual(overflowing_bbox, sequence["bbox"][-(2 + stride) :])
+ else:
+ truncated_sequence = information["input_ids"]
+ overflowing_tokens = information["overflowing_tokens"]
+ # bbox = information["bbox"]
+ # overflowing_bbox = information["overflowing_token_boxes"]
+ self.assertEqual(len(truncated_sequence), total_length - 2)
+ self.assertEqual(truncated_sequence, sequence["input_ids"][:-2])
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride)
+ self.assertEqual(overflowing_tokens, sequence["input_ids"][-(2 + stride) :])
+ # self.assertEqual(bbox, sequence["bbox"][:-2])
+ # self.assertEqual(overflowing_bbox, sequence["bbox"][-(2 + stride) :])
+
+ @unittest.skip("LayoutLMv3 tokenizer requires boxes besides sequences.")
+ def test_pretokenized_inputs(self):
+ pass
+
+ @unittest.skip("LayoutLMv3 tokenizer always expects pretokenized inputs.")
+ def test_compare_pretokenized_inputs(self):
+ pass
+
+ @unittest.skip("LayoutLMv3 fast tokenizer does not support prepare_for_model")
+ def test_compare_prepare_for_model(self):
+ pass
+
+ @slow
+ def test_only_label_first_subword(self):
+ words = ["hello", "niels"]
+ boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))]
+ word_labels = [0, 1]
+
+ # test slow tokenizer
+ tokenizer_p = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base", add_visual_labels=False)
+ encoding = tokenizer_p(words, boxes=boxes, word_labels=word_labels)
+ self.assertListEqual(encoding.labels, [-100, 0, 1, -100, -100])
+
+ tokenizer_p = LayoutLMv3Tokenizer.from_pretrained(
+ "microsoft/layoutlmv3-base",
+ only_label_first_subword=False,
+ add_visual_labels=False,
+ )
+ encoding = tokenizer_p(words, boxes=boxes, word_labels=word_labels)
+ self.assertListEqual(encoding.labels, [-100, 0, 1, 1, -100])
+
+ # test fast tokenizer
+ tokenizer_r = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base", add_visual_labels=False)
+ encoding = tokenizer_r(words, boxes=boxes, word_labels=word_labels)
+ self.assertListEqual(encoding.labels, [-100, 0, 1, -100, -100])
+
+ tokenizer_r = LayoutLMv3TokenizerFast.from_pretrained(
+ "microsoft/layoutlmv3-base",
+ only_label_first_subword=False,
+ add_visual_labels=False,
+ )
+ encoding = tokenizer_r(words, boxes=boxes, word_labels=word_labels)
+ self.assertListEqual(encoding.labels, [-100, 0, 1, 1, -100])
+
+ @slow
+ def test_layoutlmv3_integration_test(self):
+
+ tokenizer_p = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
+ tokenizer_r = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
+
+ # There are 3 cases:
+ # CASE 1: document image classification (training + inference), document image token classification (inference),
+ # in which case only words and normalized bounding boxes are provided to the tokenizer
+ # CASE 2: document image token classification (training),
+ # in which case one also provides word labels to the tokenizer
+ # CASE 3: document image visual question answering (inference),
+ # in which case one also provides a question to the tokenizer
+
+ # We need to test all 3 cases both on batched and non-batched inputs.
+
+ # CASE 1: not batched
+ words, boxes = self.get_words_and_boxes()
+
+ # fmt: off
+ expected_results = {'input_ids': [0, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'bbox': [[0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 'attention_mask': [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} # noqa: E231
+ # fmt: on
+
+ encoding_p = tokenizer_p(words, boxes=boxes, padding="max_length", max_length=20)
+ encoding_r = tokenizer_r(words, boxes=boxes, padding="max_length", max_length=20)
+ self.assertDictEqual(dict(encoding_p), expected_results)
+ self.assertDictEqual(dict(encoding_r), expected_results)
+
+ # CASE 1: batched
+ words, boxes = self.get_words_and_boxes_batch()
+
+ # fmt: off
+ expected_results = {'input_ids': [[0, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 92, 614, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'bbox': [[[0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [961, 885, 992, 912], [256, 38, 330, 58], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], 'attention_mask': [[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # noqa: E231
+ # fmt: on
+
+ encoding_p = tokenizer_p(words, boxes=boxes, padding="max_length", max_length=20)
+ encoding_r = tokenizer_r(words, boxes=boxes, padding="max_length", max_length=20)
+ self.assertDictEqual(dict(encoding_p), expected_results)
+ self.assertDictEqual(dict(encoding_r), expected_results)
+
+ # CASE 2: not batched
+ words, boxes = self.get_words_and_boxes()
+ word_labels = [1, 2]
+
+ # fmt: off
+ expected_results = {'input_ids': [0, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'bbox': [[0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 'labels': [-100, 1, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], 'attention_mask': [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} # noqa: E231
+ # fmt: on
+
+ encoding_p = tokenizer_p(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20)
+ encoding_r = tokenizer_r(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20)
+ self.assertDictEqual(dict(encoding_p), expected_results)
+ self.assertDictEqual(dict(encoding_r), expected_results)
+
+ # CASE 2: batched
+ words, boxes = self.get_words_and_boxes_batch()
+ word_labels = [[1, 2], [2, 46]]
+
+ # fmt: off
+ expected_results = {'input_ids': [[0, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 92, 614, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'bbox': [[[0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [961, 885, 992, 912], [256, 38, 330, 58], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], 'labels': [[-100, 1, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], [-100, 2, 46, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]], 'attention_mask': [[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # noqa: E231
+ # fmt: on
+
+ encoding_p = tokenizer_p(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20)
+ encoding_r = tokenizer_r(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20)
+ self.assertDictEqual(dict(encoding_p), expected_results)
+ self.assertDictEqual(dict(encoding_r), expected_results)
+
+ # CASE 3: not batched
+ question, words, boxes = self.get_question_words_and_boxes()
+
+ # fmt: off
+ expected_results = {'input_ids': [0, 99, 18, 39, 766, 116, 2, 2, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'bbox': [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]} # noqa: E231
+ # fmt: on
+
+ encoding_p = tokenizer_p(question, words, boxes, padding="max_length", max_length=20)
+ encoding_r = tokenizer_r(question, words, boxes, padding="max_length", max_length=20)
+ self.assertDictEqual(dict(encoding_p), expected_results)
+ self.assertDictEqual(dict(encoding_r), expected_results)
+
+ # CASE 3: batched
+ questions, words, boxes = self.get_question_words_and_boxes_batch()
+
+ # fmt: off
+ expected_results = {'input_ids': [[0, 99, 18, 39, 766, 116, 2, 2, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 141, 16, 37, 373, 116, 2, 2, 13964, 795, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'bbox': [[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [256, 38, 330, 58], [256, 38, 330, 58], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # noqa: E231
+ # fmt: on
+
+ encoding_p = tokenizer_p(questions, words, boxes, padding="max_length", max_length=20)
+ encoding_r = tokenizer_r(questions, words, boxes, padding="max_length", max_length=20)
+ self.assertDictEqual(dict(encoding_p), expected_results)
+ self.assertDictEqual(dict(encoding_r), expected_results)
+
+ @unittest.skip("Doesn't support another framework than PyTorch")
+ def test_np_encode_plus_sent_to_model(self):
+ pass
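A minimal, runnable sketch of the behaviour the overflow tests above assert: the fast LayoutLMv3 tokenizer returns the overflowing window as a second row of "input_ids"/"bbox", while the slow tokenizer exposes it under "overflowing_tokens"/"overflowing_token_boxes". The toy words, boxes and truncation settings below are illustrative only, not taken from the tests.

    from transformers import LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast

    words = ["hello", "world", "how", "are", "you", "doing", "today"]
    boxes = [[i, i, i + 10, i + 10] for i in range(len(words))]  # toy boxes in the 0-1000 range

    slow = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
    fast = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")

    kwargs = dict(boxes=boxes, max_length=6, truncation=True, stride=1,
                  return_overflowing_tokens=True, add_special_tokens=False)

    fast_out = fast(words, **kwargs)
    # fast: the overflow comes back as extra rows of "input_ids" and "bbox"
    print(len(fast_out["input_ids"]), fast_out["input_ids"][1], fast_out["bbox"][1])

    slow_out = slow(words, **kwargs)
    # slow: the overflow lives in dedicated keys instead
    print(slow_out["overflowing_tokens"], slow_out["overflowing_token_boxes"])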
diff --git a/tests/models/layoutxlm/test_processor_layoutxlm.py b/tests/models/layoutxlm/test_processor_layoutxlm.py
index d208097e6c80..d0d7eec28a34 100644
--- a/tests/models/layoutxlm/test_processor_layoutxlm.py
+++ b/tests/models/layoutxlm/test_processor_layoutxlm.py
@@ -22,7 +22,6 @@
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
from transformers.models.layoutxlm import LayoutXLMTokenizer, LayoutXLMTokenizerFast
from transformers.testing_utils import (
- get_tests_dir,
require_pytesseract,
require_sentencepiece,
require_tokenizers,
@@ -38,9 +37,6 @@
from transformers import LayoutLMv2FeatureExtractor, LayoutXLMProcessor
-SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model")
-
-
@require_pytesseract
@require_sentencepiece
@require_tokenizers
@@ -60,11 +56,14 @@ def setUp(self):
with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(feature_extractor_map) + "\n")
+ # taken from `test_tokenization_layoutxlm.LayoutXLMTokenizationTest.test_save_pretrained`
+ self.tokenizer_pretrained_name = "hf-internal-testing/tiny-random-layoutxlm"
+
def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
- return self.tokenizer_class.from_pretrained(SAMPLE_SP, **kwargs)
+ return self.tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs)
def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
- return self.rust_tokenizer_class.from_pretrained(SAMPLE_SP, **kwargs)
+ return self.rust_tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs)
def get_tokenizers(self, **kwargs) -> List[PreTrainedTokenizerBase]:
return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]
@@ -177,10 +176,11 @@ def test_processor_case_1(self):
)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
# fmt: off
expected_decoding = " 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President āIntroductory Remarksā Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer" # noqa: E231
# fmt: on
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -198,10 +198,11 @@ def test_processor_case_1(self):
)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
# fmt: off
expected_decoding = " 7 ITC Limited REPORT AND ACCOUNTS 2013 ITCās Brands: An Asset for the Nation The consumer needs and aspirations they fulfil, the benefit they generate for millions across ITCās value chains, the future-ready capabilities that support them, and the value that they create for the country, have made ITCās brands national assets, adding to Indiaās competitiveness. It is ITCās aspiration to be the No 1 FMCG player in the country, driven by its new FMCG businesses. A recent Nielsen report has highlighted that ITC's new FMCG businesses are the fastest growing among the top consumer goods companies operating in India. ITC takes justifiable pride that, along with generating economic value, these celebrated Indian brands also drive the creation of larger societal capital through the virtuous cycle of sustainable and inclusive growth. DI WILLS * ; LOVE DELIGHTFULLY SOFT SKIN? aia Ans Source: https://www.industrydocuments.ucsf.edu/docs/snbx0223" # noqa: E231
# fmt: on
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
@slow
@@ -228,7 +229,7 @@ def test_processor_case_2(self):
# verify input_ids
expected_decoding = " hello world"
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -243,7 +244,7 @@ def test_processor_case_2(self):
# verify input_ids
expected_decoding = " hello world"
- decoding = tokenizer.decode(input_processor.input_ids[0].tolist())
+ decoding = processor.decode(input_processor.input_ids[0].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
@@ -282,7 +283,7 @@ def test_processor_case_3(self):
# verify input_ids
expected_decoding = " weirdly world"
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify labels
@@ -304,7 +305,7 @@ def test_processor_case_3(self):
# verify input_ids
expected_decoding = " my name is niels"
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
@@ -344,10 +345,11 @@ def test_processor_case_4(self):
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
# fmt: off
expected_decoding = " What's his name? 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President āIntroductory Remarksā Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer" # noqa: E231
# fmt: on
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -362,8 +364,9 @@ def test_processor_case_4(self):
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
expected_decoding = " what's the time 7 ITC Limited REPORT AND ACCOUNTS 2013"
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
@@ -396,7 +399,7 @@ def test_processor_case_5(self):
# verify input_ids
expected_decoding = " What's his name? hello world"
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -412,11 +415,11 @@ def test_processor_case_5(self):
# verify input_ids
expected_decoding = " How old is he? hello world"
- decoding = tokenizer.decode(input_processor.input_ids[0].tolist())
+ decoding = processor.decode(input_processor.input_ids[0].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
expected_decoding = " what's the time my name is niels"
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
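The tokenizer.decode to processor.decode changes above assume the processor forwards decode to its tokenizer. A small sketch of that round trip, with apply_ocr=False so no Tesseract install is needed; the blank image and toy words/boxes are placeholders.

    from PIL import Image
    from transformers import LayoutLMv2FeatureExtractor, LayoutXLMProcessor, LayoutXLMTokenizerFast

    feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
    tokenizer = LayoutXLMTokenizerFast.from_pretrained("microsoft/layoutxlm-base")
    processor = LayoutXLMProcessor(feature_extractor, tokenizer)

    image = Image.new("RGB", (224, 224), color="white")  # stand-in for a document scan
    words = ["hello", "world"]
    boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]

    encoding = processor(image, words, boxes=boxes, return_tensors="pt")
    ids = encoding.input_ids.squeeze().tolist()
    # decode is forwarded to the underlying tokenizer, so both calls should agree
    assert processor.decode(ids) == tokenizer.decode(ids)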
diff --git a/tests/models/layoutxlm/test_tokenization_layoutxlm.py b/tests/models/layoutxlm/test_tokenization_layoutxlm.py
index 561e87e77234..68aba50ecaf4 100644
--- a/tests/models/layoutxlm/test_tokenization_layoutxlm.py
+++ b/tests/models/layoutxlm/test_tokenization_layoutxlm.py
@@ -1543,11 +1543,9 @@ def test_training_new_tokenizer_with_special_tokens_change(self):
break
self.assertTrue(
find,
- (
- f"'{new_special_token_str}' doesn't appear in the list "
- f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
- f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}"
- ),
+ f"'{new_special_token_str}' doesn't appear in the list "
+ f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
+ f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}",
)
elif special_token not in special_tokens_map:
# The special token must appear identically in the list of the new tokenizer.
diff --git a/tests/models/led/test_modeling_led.py b/tests/models/led/test_modeling_led.py
index 9d3d090ab17d..e7dc31838aa3 100644
--- a/tests/models/led/test_modeling_led.py
+++ b/tests/models/led/test_modeling_led.py
@@ -163,6 +163,7 @@ def get_config(self):
def get_pipeline_config(self):
config = self.get_config()
config.max_position_embeddings = 100
+ config.vocab_size = 300
return config
def prepare_config_and_inputs_for_common(self):
@@ -528,9 +529,26 @@ def test_seq_to_seq_generation(self):
no_repeat_ngram_size=3,
)
- EXPECTED_LEP = " the physics of @xmath0-boson will again play the central role in the frontier of particle physics if the gigaz option of the international linear collider ( ilc ) can be realized in its first phase. \n the expected sensitivity to the branching ratio of the rare decays, especially its exotic or rare processes, should be investigated comprehensively to evaluate their potential in probing new physics. in this work \n, we extend the previous studies of these decays to some new models and investigate the decays altogether. we are motivated by some recent studies on the singlet extension of the mssm, such as the next - to - minimal supersymmetric standard model ( nmssm ) @xcite and the nearly - minimal - supersymmetry - standard - model(nmssm)@xcite, where a light cp - odd higgs boson with singlet - dominant component may naturally arise from the spontaneous breaking of some approximate global symmetry. # 1#2#3#4#5#6#7#8#9#10#11#12 "
+ EXPECTED_LEP = (
+ " the physics of @xmath0-boson will again play the central role in the frontier of particle physics if the"
+ " gigaz option of the international linear collider ( ilc ) can be realized in its first phase. \n the"
+ " expected sensitivity to the branching ratio of the rare decays, especially its exotic or rare processes,"
+ " should be investigated comprehensively to evaluate their potential in probing new physics. in this work"
+ " \n, we extend the previous studies of these decays to some new models and investigate the decays"
+ " altogether. we are motivated by some recent studies on the singlet extension of the mssm, such as the"
+ " next - to - minimal supersymmetric standard model ( nmssm ) @xcite and the nearly - minimal -"
+ " supersymmetry - standard - model(nmssm)@xcite, where a light cp - odd higgs boson with singlet -"
+ " dominant component may naturally arise from the spontaneous breaking of some approximate global"
+ " symmetry. # 1#2#3#4#5#6#7#8#9#10#11#12 "
+ )
- EXPECTED_MAGNET = " the recent experiment in the surface states of the topological insulator bi@xmath0se @xmath1, however, reported that a large positive magnetoresistance becomes very linear in perpendicular magnetic field even in an opposite situation where the carrier sheet density is high that all electrons occupy more than one landau levels. \n it is striking that this observation is in conflict with abrikosov s model and also with the classical parish - littlewood model. "
+ EXPECTED_MAGNET = (
+ " the recent experiment in the surface states of the topological insulator bi@xmath0se @xmath1, however,"
+ " reported that a large positive magnetoresistance becomes very linear in perpendicular magnetic field"
+ " even in an opposite situation where the carrier sheet density is high that all electrons occupy more"
+ " than one landau levels. \n it is striking that this observation is in conflict with abrikosov s model"
+ " and also with the classical parish - littlewood model. "
+ )
generated = tok.batch_decode(
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
diff --git a/tests/models/led/test_modeling_tf_led.py b/tests/models/led/test_modeling_tf_led.py
index 8075d071e662..dfdb66606faf 100644
--- a/tests/models/led/test_modeling_tf_led.py
+++ b/tests/models/led/test_modeling_tf_led.py
@@ -17,7 +17,7 @@
import unittest
from transformers import LEDConfig, is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers.testing_utils import require_tf, slow, tooslow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
@@ -365,8 +365,8 @@ def test_xla_mode(self):
# TODO JP: Make LED XLA compliant
pass
+ @tooslow
def test_saved_model_creation(self):
- # This test is too long (>30sec) and makes fail the CI
pass
def test_generate_with_headmasking(self):
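The two hunks above replace the inline "too long for CI" comment with the tooslow marker imported from transformers.testing_utils. A hedged sketch of the pattern, assuming tooslow behaves like an unconditional unittest.skip:

    import unittest

    from transformers.testing_utils import tooslow


    class ExampleModelTest(unittest.TestCase):
        @tooslow
        def test_saved_model_creation(self):
            # takes >30s, so it is skipped rather than run on CI
            ...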
diff --git a/tests/models/levit/__init__.py b/tests/models/levit/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/levit/test_feature_extraction_levit.py b/tests/models/levit/test_feature_extraction_levit.py
new file mode 100644
index 000000000000..98a704b97a62
--- /dev/null
+++ b/tests/models/levit/test_feature_extraction_levit.py
@@ -0,0 +1,195 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import LevitFeatureExtractor
+
+
+class LevitFeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ image_size=18,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=18,
+ do_center_crop=True,
+ do_normalize=True,
+ image_mean=[0.5, 0.5, 0.5],
+ image_std=[0.5, 0.5, 0.5],
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.do_center_crop = do_center_crop
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_normalize": self.do_normalize,
+ "do_resize": self.do_resize,
+ "do_center_crop": self.do_center_crop,
+ "size": self.size,
+ }
+
+
+@require_torch
+@require_vision
+class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = LevitFeatureExtractor if is_vision_available() else None
+
+ def setUp(self):
+ self.feature_extract_tester = LevitFeatureExtractionTester(self)
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "image_std"))
+ self.assertTrue(hasattr(feature_extractor, "do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "size"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ def test_call_numpy(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random numpy tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, np.ndarray)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ def test_call_pytorch(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PyTorch tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, torch.Tensor)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
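A short usage sketch of the feature extractor the tests above exercise, built with the same toy configuration as LevitFeatureExtractionTester (size 18, center crop, normalization); the random image is a stand-in for real pixel data.

    import numpy as np
    from transformers import LevitFeatureExtractor

    feature_extractor = LevitFeatureExtractor(
        do_resize=True,
        size=18,
        do_center_crop=True,
        do_normalize=True,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    )

    # one random RGB image as an (H, W, C) uint8 array
    image = np.random.randint(0, 256, (30, 40, 3), dtype=np.uint8)

    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    print(pixel_values.shape)  # expected: torch.Size([1, 3, 18, 18]), matching the shape checks above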
diff --git a/tests/models/levit/test_modeling_levit.py b/tests/models/levit/test_modeling_levit.py
new file mode 100644
index 000000000000..725b279fd02f
--- /dev/null
+++ b/tests/models/levit/test_modeling_levit.py
@@ -0,0 +1,427 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch LeViT model. """
+
+
+import inspect
+import unittest
+import warnings
+from math import ceil, floor
+
+from transformers import LevitConfig
+from transformers.file_utils import cached_property, is_torch_available, is_vision_available
+from transformers.models.auto import get_values
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+ MODEL_MAPPING,
+ LevitForImageClassification,
+ LevitForImageClassificationWithTeacher,
+ LevitModel,
+ )
+ from transformers.models.levit.modeling_levit import LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import LevitFeatureExtractor
+
+
+class LevitConfigTester(ConfigTester):
+ def create_and_test_config_common_properties(self):
+ config = self.config_class(**self.inputs_dict)
+ self.parent.assertTrue(hasattr(config, "hidden_sizes"))
+ self.parent.assertTrue(hasattr(config, "num_attention_heads"))
+
+
+class LevitModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=64,
+ num_channels=3,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ patch_size=16,
+ hidden_sizes=[128, 256, 384],
+ num_attention_heads=[4, 6, 8],
+ depths=[2, 3, 4],
+ key_dim=[16, 16, 16],
+ drop_path_rate=0,
+ mlp_ratio=[2, 2, 2],
+ attention_ratio=[2, 2, 2],
+ initializer_range=0.02,
+ is_training=True,
+ use_labels=True,
+ num_labels=2, # Check
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.num_channels = num_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.hidden_sizes = hidden_sizes
+ self.num_attention_heads = num_attention_heads
+ self.depths = depths
+ self.key_dim = key_dim
+ self.drop_path_rate = drop_path_rate
+ self.patch_size = patch_size
+ self.attention_ratio = attention_ratio
+ self.mlp_ratio = mlp_ratio
+ self.initializer_range = initializer_range
+ self.down_ops = [
+ ["Subsample", key_dim[0], hidden_sizes[0] // key_dim[0], 4, 2, 2],
+ ["Subsample", key_dim[0], hidden_sizes[1] // key_dim[0], 4, 2, 2],
+ ]
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.num_labels = num_labels
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.num_labels)
+
+ config = self.get_config()
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return LevitConfig(
+ image_size=self.image_size,
+ num_channels=self.num_channels,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding,
+ patch_size=self.patch_size,
+ hidden_sizes=self.hidden_sizes,
+ num_attention_heads=self.num_attention_heads,
+ depths=self.depths,
+ key_dim=self.key_dim,
+ drop_path_rate=self.drop_path_rate,
+ mlp_ratio=self.mlp_ratio,
+ attention_ratio=self.attention_ratio,
+ initializer_range=self.initializer_range,
+ down_ops=self.down_ops,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = LevitModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ image_size = (self.image_size, self.image_size)
+ height, width = image_size[0], image_size[1]
+ for _ in range(4):
+ height = floor(((height + 2 * self.padding - self.kernel_size) / self.stride) + 1)
+ width = floor(((width + 2 * self.padding - self.kernel_size) / self.stride) + 1)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, ceil(height / 4) * ceil(width / 4), self.hidden_sizes[-1]),
+ )
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.num_labels
+ model = LevitForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class LevitModelTest(ModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as Levit does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (
+ (LevitModel, LevitForImageClassification, LevitForImageClassificationWithTeacher)
+ if is_torch_available()
+ else ()
+ )
+
+ test_pruning = False
+ test_torchscript = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ has_attentions = False
+
+ def setUp(self):
+ self.model_tester = LevitModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=LevitConfig, has_text_modality=False, hidden_size=37)
+
+ def test_config(self):
+ self.create_and_test_config_common_properties()
+ self.config_tester.create_and_test_config_to_json_string()
+ self.config_tester.create_and_test_config_to_json_file()
+ self.config_tester.create_and_test_config_from_and_save_pretrained()
+ self.config_tester.create_and_test_config_with_num_labels()
+ self.config_tester.check_config_can_be_init_without_params()
+ self.config_tester.check_config_arguments_init()
+
+ def create_and_test_config_common_properties(self):
+ return
+
+ @unittest.skip(reason="Levit does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="Levit does not support input and output embeddings")
+ def test_model_common_attributes(self):
+ pass
+
+ @unittest.skip(reason="Levit does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.hidden_states
+
+ expected_num_layers = len(self.model_tester.depths) + 1
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ image_size = (self.model_tester.image_size, self.model_tester.image_size)
+ height, width = image_size[0], image_size[1]
+ for _ in range(4):
+ height = floor(
+ (
+ (height + 2 * self.model_tester.padding - self.model_tester.kernel_size)
+ / self.model_tester.stride
+ )
+ + 1
+ )
+ width = floor(
+ (
+ (width + 2 * self.model_tester.padding - self.model_tester.kernel_size)
+ / self.model_tester.stride
+ )
+ + 1
+ )
+ # verify the first hidden states (first block)
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [
+ height * width,
+ self.model_tester.hidden_sizes[0],
+ ],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+ if return_labels:
+ if model_class.__name__ == "LevitForImageClassificationWithTeacher":
+ del inputs_dict["labels"]
+
+ return inputs_dict
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ # special case for LevitForImageClassificationWithTeacher model
+ def test_training(self):
+ if not self.model_tester.is_training:
+ return
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ for model_class in self.all_model_classes:
+ # LevitForImageClassificationWithTeacher supports inference-only
+ if (
+ model_class in get_values(MODEL_MAPPING)
+ or model_class.__name__ == "LevitForImageClassificationWithTeacher"
+ ):
+ continue
+ model = model_class(config)
+ model.to(torch_device)
+ model.train()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ loss = model(**inputs).loss
+ loss.backward()
+
+ def test_training_gradient_checkpointing(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ if not self.model_tester.is_training:
+ return
+
+ config.use_cache = False
+ config.return_dict = True
+
+ for model_class in self.all_model_classes:
+ if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
+ continue
+ # LevitForImageClassificationWithTeacher supports inference-only
+ if model_class.__name__ == "LevitForImageClassificationWithTeacher":
+ continue
+ model = model_class(config)
+ model.gradient_checkpointing_enable()
+ model.to(torch_device)
+ model.train()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ loss = model(**inputs).loss
+ loss.backward()
+
+ def test_problem_types(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ problem_types = [
+ {"title": "multi_label_classification", "num_labels": 2, "dtype": torch.float},
+ {"title": "single_label_classification", "num_labels": 1, "dtype": torch.long},
+ {"title": "regression", "num_labels": 1, "dtype": torch.float},
+ ]
+
+ for model_class in self.all_model_classes:
+ if (
+ model_class
+ not in [
+ *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
+ ]
+ or model_class.__name__ == "LevitForImageClassificationWithTeacher"
+ ):
+ continue
+
+ for problem_type in problem_types:
+ with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"):
+
+ config.problem_type = problem_type["title"]
+ config.num_labels = problem_type["num_labels"]
+
+ model = model_class(config)
+ model.to(torch_device)
+ model.train()
+
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+
+ if problem_type["num_labels"] > 1:
+ inputs["labels"] = inputs["labels"].unsqueeze(1).repeat(1, problem_type["num_labels"])
+
+ inputs["labels"] = inputs["labels"].to(problem_type["dtype"])
+
+ # This tests that we do not trigger the warning from PyTorch "Using a target size that is different
+ # to the input size. This will likely lead to incorrect results due to broadcasting. Please ensure
+ # they have the same size." which is a symptom that something is wrong for the regression problem.
+ # See https://github.com/huggingface/transformers/issues/11780
+ with warnings.catch_warnings(record=True) as warning_list:
+ loss = model(**inputs).loss
+ for w in warning_list:
+ if "Using a target size that is different to the input size" in str(w.message):
+ raise ValueError(
+ f"Something is going wrong in the regression problem: intercepted {w.message}"
+ )
+
+ loss.backward()
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = LevitModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_torch
+@require_vision
+class LevitModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return LevitFeatureExtractor.from_pretrained(LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST[0])
+
+ @slow
+ def test_inference_image_classification_head(self):
+ model = LevitForImageClassificationWithTeacher.from_pretrained(LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST[0]).to(
+ torch_device
+ )
+
+ feature_extractor = self.default_feature_extractor
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ expected_shape = torch.Size((1, 1000))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor([1.0448, -0.3745, -1.8317]).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
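The shape assertions in create_and_check_model and test_hidden_states_output above both rely on the standard convolution output-size formula applied once per stem convolution (kernel 3, stride 2, padding 1). A small worked check of that arithmetic for the tester's 64x64 input; the helper name is ours, not part of the model code.

    from math import ceil, floor


    def stem_resolution(size, num_convs=4, kernel_size=3, stride=2, padding=1):
        """Apply floor((size + 2*padding - kernel_size) / stride) + 1 once per stem convolution."""
        for _ in range(num_convs):
            size = floor((size + 2 * padding - kernel_size) / stride) + 1
        return size


    side = stem_resolution(64)  # 64 -> 32 -> 16 -> 8 -> 4
    print(side)  # 4
    # sequence length expected for the final hidden state after the two Subsample stages
    print(ceil(side / 4) * ceil(side / 4))  # 1, i.e. last_hidden_state is (batch_size, 1, hidden_sizes[-1])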
diff --git a/tests/models/longformer/test_modeling_longformer.py b/tests/models/longformer/test_modeling_longformer.py
index fd10b14eaead..c1839d67d36c 100644
--- a/tests/models/longformer/test_modeling_longformer.py
+++ b/tests/models/longformer/test_modeling_longformer.py
@@ -113,6 +113,11 @@ def get_config(self):
attention_window=self.attention_window,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def create_and_check_attention_mask_determinism(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
diff --git a/tests/models/longformer/test_modeling_tf_longformer.py b/tests/models/longformer/test_modeling_tf_longformer.py
index 12c19e566e95..cc62bb6caf70 100644
--- a/tests/models/longformer/test_modeling_tf_longformer.py
+++ b/tests/models/longformer/test_modeling_tf_longformer.py
@@ -17,7 +17,7 @@
import unittest
from transformers import is_tf_available
-from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
+from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
@@ -326,8 +326,8 @@ def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
+ @tooslow
def test_saved_model_creation(self):
- # This test is too long (>30sec) and makes fail the CI
pass
def test_xla_mode(self):
diff --git a/tests/models/longt5/__init__.py b/tests/models/longt5/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/longt5/test_modeling_flax_longt5.py b/tests/models/longt5/test_modeling_flax_longt5.py
new file mode 100644
index 000000000000..9406e292d177
--- /dev/null
+++ b/tests/models/longt5/test_modeling_flax_longt5.py
@@ -0,0 +1,757 @@
+# coding=utf-8
+# Copyright 2022 Google LongT5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import tempfile
+import unittest
+
+import numpy as np
+
+import transformers
+from transformers import is_flax_available
+from transformers.models.auto import get_values
+from transformers.testing_utils import (
+ is_pt_flax_cross_test,
+ require_flax,
+ require_sentencepiece,
+ require_tokenizers,
+ slow,
+)
+
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
+
+
+if is_flax_available():
+ import os
+
+ # The slow tests often fail with an OOM error on GPU.
+ # This makes JAX allocate exactly what is needed on demand and deallocate memory that is no longer needed,
+ # but it will be slower, as stated here: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
+ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
+
+ import jax
+ import jax.numpy as jnp
+ from flax.core.frozen_dict import unfreeze
+ from flax.traverse_util import flatten_dict
+ from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING, AutoTokenizer, LongT5Config
+ from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
+ from transformers.models.longt5.modeling_flax_longt5 import (
+ FlaxLongT5ForConditionalGeneration,
+ FlaxLongT5Model,
+ shift_tokens_right,
+ )
+
+
+class FlaxLongT5ModelTester:
+ def __init__(
+ self,
+ parent,
+ vocab_size=99,
+ batch_size=13,
+ encoder_seq_length=7,
+ decoder_seq_length=9,
+ local_radius=5,
+ encoder_attention_type="local",
+ global_block_size=3,
+ # For common tests
+ is_training=True,
+ use_attention_mask=True,
+ use_labels=True,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ d_ff=37,
+ relative_attention_num_buckets=8,
+ dropout_rate=0.1,
+ initializer_factor=0.002,
+ eos_token_id=1,
+ pad_token_id=0,
+ decoder_start_token_id=0,
+ scope=None,
+ decoder_layers=None,
+ ):
+
+ self.parent = parent
+ self.batch_size = batch_size
+ self.encoder_seq_length = encoder_seq_length
+ self.decoder_seq_length = decoder_seq_length
+ self.local_radius = local_radius
+ self.block_len = local_radius + 1
+ self.encoder_attention_type = encoder_attention_type
+ self.global_block_size = global_block_size
+ # For common tests
+ self.seq_length = self.decoder_seq_length
+ self.is_training = is_training
+ self.use_attention_mask = use_attention_mask
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.d_ff = d_ff
+ self.relative_attention_num_buckets = relative_attention_num_buckets
+ self.dropout_rate = dropout_rate
+ self.initializer_factor = initializer_factor
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.decoder_start_token_id = decoder_start_token_id
+ self.scope = scope
+ self.decoder_layers = decoder_layers
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
+ decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
+
+ attention_mask = None
+ decoder_attention_mask = None
+ if self.use_attention_mask:
+ attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
+ decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
+
+ config = LongT5Config(
+ vocab_size=self.vocab_size,
+ d_model=self.hidden_size,
+ d_ff=self.d_ff,
+ d_kv=self.hidden_size // self.num_attention_heads,
+ num_layers=self.num_hidden_layers,
+ num_decoder_layers=self.decoder_layers,
+ num_heads=self.num_attention_heads,
+ relative_attention_num_buckets=self.relative_attention_num_buckets,
+ dropout_rate=self.dropout_rate,
+ initializer_factor=self.initializer_factor,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.pad_token_id,
+ pad_token_id=self.pad_token_id,
+ decoder_start_token_id=self.decoder_start_token_id,
+ local_radius=self.local_radius,
+ encoder_attention_type=self.encoder_attention_type,
+ global_block_size=self.global_block_size,
+ )
+
+ return (
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ )
+
+ def create_and_check_model(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ ):
+ model = FlaxLongT5Model(config=config)
+ result = model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ )
+ result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+ decoder_output = result.last_hidden_state
+ encoder_output = result.encoder_last_hidden_state
+
+ self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size))
+ self.parent.assertEqual(decoder_output.shape, (self.batch_size, self.decoder_seq_length, self.hidden_size))
+
+ def check_use_cache_forward_with_attn_mask(
+ self,
+ model_class_name,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ ):
+ max_decoder_length = 20
+ model = model_class_name(config)
+
+ encoder_outputs = model.encode(input_ids)
+
+ # prevent fully zeroed-out attention mask
+ decoder_attention_mask = jnp.ones_like(decoder_attention_mask)
+
+ decoder_attention_mask_cache = jnp.concatenate(
+ [
+ decoder_attention_mask,
+ jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])),
+ ],
+ axis=-1,
+ )
+
+ past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)
+
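+ # decode everything except the last token against the full-length cache, then feed the final
+ # token together with the returned past_key_values; the cached path should match the
+ # non-cached forward pass computed below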
+ outputs_cache = model.decode(
+ decoder_input_ids[:, :-1],
+ encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask_cache,
+ past_key_values=past_key_values,
+ )
+ outputs_cache_next = model.decode(
+ decoder_input_ids[:, -1:],
+ encoder_outputs,
+ past_key_values=outputs_cache.past_key_values,
+ decoder_attention_mask=decoder_attention_mask_cache,
+ )
+
+ outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
+
+ diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
+ self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ ) = config_and_inputs
+
+ inputs_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": decoder_attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_flax
+class FlaxLongT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
+
+ all_model_classes = (FlaxLongT5Model, FlaxLongT5ForConditionalGeneration) if is_flax_available() else ()
+ all_generative_model_classes = (FlaxLongT5ForConditionalGeneration,) if is_flax_available() else ()
+ is_encoder_decoder = True
+
+ def setUp(self):
+ self.model_tester = FlaxLongT5ModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_v1_1(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ # check that gated gelu feed forward and different word embeddings work
+ config = config_and_inputs[0]
+ config.tie_word_embeddings = False
+ config.feed_forward_proj = "gated-gelu"
+ self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
+
+ def test_use_cache_forward_with_attn_mask(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ for model_class in self.all_model_classes:
+ self.model_tester.check_use_cache_forward_with_attn_mask(model_class, *config_and_inputs)
+
+ def test_encode(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ with self.subTest(model_class.__name__):
+ prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ model = model_class(config)
+
+ @jax.jit
+ def encode_jitted(input_ids, attention_mask=None, **kwargs):
+ return model.encode(input_ids=input_ids, attention_mask=attention_mask)
+
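+ # the jitted and non-jitted encoder calls should return tuples of the same length with
+ # matching output shapes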
+ with self.subTest("JIT Enabled"):
+ jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
+
+ with self.subTest("JIT Disabled"):
+ with jax.disable_jit():
+ outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
+
+ self.assertEqual(len(outputs), len(jitted_outputs))
+ for jitted_output, output in zip(jitted_outputs, outputs):
+ self.assertEqual(jitted_output.shape, output.shape)
+
+ def test_decode(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ with self.subTest(model_class.__name__):
+ model = model_class(config)
+ encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])
+
+ prepared_inputs_dict = {
+ "decoder_input_ids": inputs_dict["decoder_input_ids"],
+ "decoder_attention_mask": inputs_dict["decoder_attention_mask"],
+ "encoder_outputs": encoder_outputs,
+ }
+
+ @jax.jit
+ def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs):
+ return model.decode(
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_outputs=encoder_outputs,
+ )
+
+ with self.subTest("JIT Enabled"):
+ jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
+
+ with self.subTest("JIT Disabled"):
+ with jax.disable_jit():
+ outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
+
+ self.assertEqual(len(outputs), len(jitted_outputs))
+ for jitted_output, output in zip(jitted_outputs, outputs):
+ self.assertEqual(jitted_output.shape, output.shape)
+
+ def test_shift_right(self):
+ decoder_start_token_id = 0
+ pad_token_id = 1
+ labels = np.arange(2, 102).reshape(5, 20)
+ labels[:2, 15:] = -100
+
+ decoder_input_ids = shift_tokens_right(labels, pad_token_id, decoder_start_token_id)
+ np_decoder_input_ids = np.array(decoder_input_ids)
+
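+ # after shifting, position 0 holds the decoder start token, the -100 label positions are
+ # replaced by the pad token, and the remaining positions contain the labels shifted right by one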
+ padded_slice = np_decoder_input_ids[:2, (15 + 1) :]
+ self.assertTrue((padded_slice == 1).all())
+
+ not_padded_slice = np_decoder_input_ids[2:, 1:]
+ rolled_labels = np.roll(labels[2:], 1)[:, 1:]
+ self.assertTrue((not_padded_slice == rolled_labels).all())
+ self.assertTrue((np_decoder_input_ids[:, 0] == 0).all())
+
+ # overwrite since special base model prefix is used
+ def test_save_load_from_base(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
+ model = base_class(config)
+ base_params = flatten_dict(unfreeze(model.params))
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ head_model = model_class.from_pretrained(tmpdirname)
+
+ base_param_from_head = flatten_dict(unfreeze(head_model.params))
+
+ for key in base_param_from_head.keys():
+ max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+ # overwrite since special base model prefix is used
+ def test_save_load_to_base(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
+ model = model_class(config)
+ base_params_from_head = flatten_dict(unfreeze(model.params))
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ base_model = base_class.from_pretrained(tmpdirname)
+
+ base_params = flatten_dict(unfreeze(base_model.params))
+
+ for key in base_params_from_head.keys():
+ max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ seq_length = getattr(self.model_tester, "seq_length", None)
+ decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
+ decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+ block_len = getattr(self.model_tester, "block_len", None)
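+ # with LongT5 local attention, each block of `block_len` query tokens attends to three key
+ # blocks (the previous, current and next block), so the encoder attention weights below have
+ # shape (..., num_attention_heads, block_len, 3 * block_len)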
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len],
+ )
+ out_len = len(outputs)
+
+ if self.is_encoder_decoder:
+ correct_outlen = 5
+
+ # Question Answering model returns start_logits and end_logits
+ if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
+ correct_outlen += 1 # start_logits and end_logits instead of only 1 output
+
+ self.assertEqual(out_len, correct_outlen)
+
+ # decoder attentions
+ decoder_attentions = outputs.decoder_attentions
+ self.assertIsInstance(decoder_attentions, (list, tuple))
+ self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(decoder_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
+ )
+
+ # cross attentions
+ cross_attentions = outputs.cross_attentions
+ self.assertIsInstance(cross_attentions, (list, tuple))
+ self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(cross_attentions[0].shape[-3:]),
+ [
+ self.model_tester.num_attention_heads,
+ decoder_seq_length,
+ encoder_key_length,
+ ],
+ )
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ elif self.is_encoder_decoder:
+ added_hidden_states = 2
+ else:
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len],
+ )
+
+ # overwrite since special base model prefix is used
+ @is_pt_flax_cross_test
+ def test_save_load_from_base_pt(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
+ model = base_class(config)
+ base_params = flatten_dict(unfreeze(model.params))
+
+ # convert Flax model to PyTorch model
+ pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
+ pt_model = pt_model_class(config).eval()
+ pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ # save pt model
+ pt_model.save_pretrained(tmpdirname)
+ head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
+
+ base_param_from_head = flatten_dict(unfreeze(head_model.params))
+
+ for key in base_param_from_head.keys():
+ max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+ # overwrite since special base model prefix is used
+ @is_pt_flax_cross_test
+ def test_save_load_to_base_pt(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
+ model = model_class(config)
+ base_params_from_head = flatten_dict(unfreeze(model.params))
+
+ # convert Flax model to PyTorch model
+ pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
+ pt_model = pt_model_class(config).eval()
+ pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pt_model.save_pretrained(tmpdirname)
+ base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
+
+ base_params = flatten_dict(unfreeze(base_model.params))
+
+ for key in base_params_from_head.keys():
+ max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+ # overwrite since special base model prefix is used
+ @is_pt_flax_cross_test
+ def test_save_load_bf16_to_base_pt(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
+ model = model_class(config)
+ model.params = model.to_bf16(model.params)
+ base_params_from_head = flatten_dict(unfreeze(model.params))
+
+ # convert Flax model to PyTorch model
+ pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
+ pt_model = pt_model_class(config).eval()
+ pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pt_model.save_pretrained(tmpdirname)
+ base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
+
+ base_params = flatten_dict(unfreeze(base_model.params))
+
+ for key in base_params_from_head.keys():
+ max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+
+class FlaxLongT5TGlobalModelTest(FlaxLongT5ModelTest):
+ def setUp(self):
+ self.model_tester = FlaxLongT5ModelTester(self, encoder_attention_type="transient-global")
+ self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ seq_length = getattr(self.model_tester, "seq_length", None)
+ decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
+ decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+ block_len = getattr(self.model_tester, "block_len", None)
+ global_block_size = getattr(self.model_tester, "global_block_size", None)
+ global_seq_len = encoder_seq_length // global_block_size
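+ # transient-global attention adds one global token per `global_block_size` input tokens, so the
+ # key dimension of the encoder attention weights grows by `global_seq_len` compared to the
+ # purely local variant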
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
+ )
+ out_len = len(outputs)
+
+ if self.is_encoder_decoder:
+ correct_outlen = 5
+
+ # Question Answering model returns start_logits and end_logits
+ if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
+ correct_outlen += 1 # start_logits and end_logits instead of only 1 output
+
+ self.assertEqual(out_len, correct_outlen)
+
+ # decoder attentions
+ decoder_attentions = outputs.decoder_attentions
+ self.assertIsInstance(decoder_attentions, (list, tuple))
+ self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(decoder_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
+ )
+
+ # cross attentions
+ cross_attentions = outputs.cross_attentions
+ self.assertIsInstance(cross_attentions, (list, tuple))
+ self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(cross_attentions[0].shape[-3:]),
+ [
+ self.model_tester.num_attention_heads,
+ decoder_seq_length,
+ encoder_key_length,
+ ],
+ )
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ elif self.is_encoder_decoder:
+ added_hidden_states = 2
+ else:
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
+ )
+
+
+@require_sentencepiece
+@require_tokenizers
+@require_flax
+class FlaxLongT5ModelIntegrationTests(unittest.TestCase):
+ model_path = "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps"
+
+ def expected_summary(self):
+ return [
+ "background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in"
+ " developing world . it provides an excellent resolution for visualization of the coronary arteries for"
+ " catheter - based or operating interventions . although the association of this technique with major"
+ " complications such as mortality is highly uncommon , it is frequently associated with various cardiac"
+ " and noncardiac complications . computed tomography coronary angiography is a promising technique for the"
+ " evaluation of cad noninvasively . it assesses disease within the coronary artery and provides"
+ " qualitative and quantitative information about nonobstructive atherosclerotic plaque"
+ ]
+
+ @slow
+ def test_summarization(self):
+ model = FlaxLongT5ForConditionalGeneration.from_pretrained(self.model_path)
+ tok = AutoTokenizer.from_pretrained(self.model_path)
+
+ ARTICLE = """coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . \n it provides an excellent resolution for visualization of the coronary arteries for catheter - based or operating interventions . \n
+ although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications . computed tomography ( ct ) coronary angiography is
+ a promising technique for the evaluation of cad noninvasively . \n it assesses disease within the coronary artery and provides qualitative and quantitative information about nonobstructive atherosclerotic plaque burden within the vessel
+ wall . \n thus , ct angiography - based disease evaluation may provide clinically more significant information than conventional angiography . the introduction of multi - slice computed tomography ( msct ) technology such as 64-slice , 12
+ 8-slice , 256-slice , and now 320-slice msct has produced a high diagnostic accuracy of ct coronary angiography . \n it has consistently showed to have a very high negative predictive value ( well above 90% ) in ruling out patients with s
+ ignificant cad defined as coronary luminal stenosis of > 50% . \n the american college of cardiology / american heart association recommends that coronary angiography should be performed before valve surgery in men aged > 40 years , women
+ aged > 35 years with coronary risk factors and in postmenopausal women . \n the prevalence of cad in patients undergoing valve replacement is 2040% in developed countries . in the previous studies , \n the incidence of angiographically p
+ roven cad in acquired valvular diseases has been shown to vary widely from 9% to 41% . in aortic stenosis , \n we aimed to report the diagnostic performance of 128-slice ct coronary angiography in 50 patients undergoing for major noncoron
+ ary cardiac surgery referred for diagnostic invasive coronary angiography to assess the extent and severity of coronary stenosis . \n during january 2013 to december 2014 , we enrolled fifty major noncoronary cardiac surgery patients sche
+ duled for invasive coronary angiography who fulfilled the following inclusion criteria of age 40 years , having low or intermediate probability of cad , left ventricular ejection fraction ( lvef ) > 35% , and patient giving informed conse
+ nt for undergoing msct and conventional coronary angiography . \n those having any contraindication for contrast injection , lvef < 35% , high pretest probability of cad , and hemodynamic instability were excluded from the study . \n pati
+ ents with heart rates of > 70 bpm received ( unless they had known overt heart failure or electrocardiogram ( ecg ) atrioventricular conduction abnormalities ) a single oral dose of 100 mg metoprolol 45 min before the scan . \n patients w
+ ith heart rates of > 80 bpm received an additional oral dose of metoprolol if not contraindicated . \n all patients were scanned with a 128-slice ct scanner ( siemens , somatom definition as ) equipped with a new feature in msct technolog
+ y , so - called z - axis flying - focus technology . \n the central 32 detector rows acquire 0.6-mm slices , and the flying - focus spot switches back and forth between 2 z positions between each reading . \n two slices per detector row a
+ re acquired , which results in a higher oversampling rate in the z - axis , thereby reducing artifacts related to the spiral acquisition and improving spatial resolution down to 0.4 mm . \n a bolus of 6580 ml contrast material ( omnipaque
+ ) was injected through an arm vein at a flow rate of 5 ml / s . \n a bolus tracking technique was used to synchronize the arrival of contrast in the coronary arteries with the initiation of the scan . to monitor the arrival of contrast m
+ aterial , \n axial scans were obtained at the level of the ascending aorta with a delay of 10 s after the start of the contrast injection . \n the scan was automatically started when a threshold of 150 hounsfield units was reached in a re
+ gion of interest positioned in the ascending aorta . \n images were reconstructed with ecg gating to obtain optimal , motion - free image quality . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a s
+ ingle observer unaware of the multi - slice ct results identified coronary lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiograp
+ hy . \n lesions were classified as having nonsignificant disease ( luminal irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean
+ lumen diameter reduction was 50% using a validated quantitative coronary angiography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiograp
+ hy . \n total calcium scores of all patients were calculated with dedicated software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of th
+ e number , areas , and peak hounsfield units of the detected calcified lesions . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were
+ used to identify coronary lesions and ( curved ) multiplanar reconstructions to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the di
+ agnostic performance of ct coronary angiography for the detection of significant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and
+ positive and negative likelihood ratios with the corresponding exact 95% of confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease p
+ er vessel ) , and patient by patient ( no or any disease per patient ) . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a single observer unaware of the multi - slice ct results identified coronary
+ lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiography . \n lesions were classified as having nonsignificant disease ( luminal
+ irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean lumen diameter reduction was 50% using a validated quantitative coronary an
+ giography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiography . \n total calcium scores of all patients were calculated with dedicated
+ software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of the number , areas , and peak hounsfield units of the detected calcified lesi
+ ons . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were used to identify coronary lesions and ( curved ) multiplanar reconstruction
+ s to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the diagnostic performance of ct coronary angiography for the detection of signif
+ icant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and positive and negative likelihood ratios with the corresponding exact 95% of
+ confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease per vessel ) , and patient by patient ( no or any disease per patient ) . \n
+ in this study , 29 ( 58% ) subjects were female , and 21 ( 42% ) were male showing an average age of 50.36 8.39 years . \n of fifty patients 24 ( 48% ) , 13 ( 26% ) , eight ( 16% ) , and five ( 10% ) underwent mitral valve replacement ,
+ double valve replacement ( dvr ) , aortic valve replacement , and other surgeries , respectively . \n high distribution of cad risk factors such as hypertension ( 24% ) , smoking ( 22% ) , and dyslipidemia ( 18% ) was observed in the stu
+ dy group . \n the mean creatinine level was 0.766 0.17 and average dye used in conventional angiography was 48.5 26.6 whereas for ct angiography it was 72.8 6.32 . \n average radiation dose in conventional coronary angiography and msct
+ coronary angiography was 5.2 msv and 9.2 msv , respectively . \n the majority of the patients had sinus rhythm ( 68% ) , whereas atrial fibrillation was found in 32% of the subjects . \n patients included in the study had low to intermed
+ iate probability of cad . in this study , three patients had complications after conventional angiography . \n complications were of local site hematoma , acute kidney injury managed conservatively , and acute heart failure . \n a patient
+ who developed hematoma was obese female patients with body mass index > 30 kg / m . \n the patient suffered from pseudoaneurysm , had hospitalized for 9 days , which leads to increased morbidity and cost of hospital stay . \n the diagnos
+ tic accuracy of ct coronary angiography was evaluated regarding true positive , true negative values and is presented in table 1 . the overall sensitivity and \n specificity of ct angiography technique was 100% ( 95% ci : 39.76%100% ) and
+ 91.30% ( 95% ci : 79.21%97.58% ) , respectively [ table 2 ] . \n the positive predictive value ( 50% ; 95% ci : 15.70%84.30% ) and negative predictive value ( 100% ; 95% ci : 91.59%100% ) of ct angiography were also fairly high in these
+ patients . \n recent reports from multiple studies demonstrated that recent - generation msct scanners showed promise for noninvasive detection of coronary stenosis however , until now no studies were found regarding the clinical efficacy
+ or prognostic value of 128-slice ct coronary angiography versus conventional invasive coronary angiography in the diagnosis of patients planned for major noncoronary surgeries such as dvr , bentall , atrial septal defect closure , etc .
+ in our study , we reported 8% cad prevalence in patients planned for major noncoronary cardiac surgery . \n we performed conventional and msct coronary angiography in all patients and the results showed that ct coronary angiography with i
+ nvasive coronary angiography as the reference standard had a considerably high sensitivity ( 100% ) and specificity ( 95.65% ) . \n the health economic model using invasive coronary angiography as the reference standard showed that at a p
+ retest probability of cad of 70% or lower , ct coronary angiography resulted in lower cost per patient with a true positive diagnosis . at a pretest probability of cad of 70% or higher , invasive coronary angiography was associated with a
+ lower cost per patient with a true positive diagnosis . in our study population , \n two patients developed local site complications in the form of hematoma and pseudoaneurysm after conventional angiography . \n hence , msct coronary ang
+ iography will be more favorable in female obese patients with intermediate likelihood of cad . \n hence , msct coronary angiography will be cost - effective in patients of valvular heart diseases . \n however , ct angiography suffers from
+ a drawback that average amount of dye used in msct coronary angiography were 72.8 6.32 ml which is higher than average amount of dye required for conventional angiography ( 48.6 26.6 ml ) . \n hence , the use of ct coronary angiography
+ could not be used in patients with known renal dysfunction , where reduction of contrast dye load is highly advocated . \n our results show that 128-slice ct coronary angiography is a reliable technique to detect coronary stenosis in pat
+ ients planned for noncoronary cardiac surgery . \n although there has been important technological progress in the development of ct coronary angiography , its clinical application remains limited . \n a study wth large numbers of patient
+ s is required for the recommendation of only ct coronary angiography for the coronary evaluation in major non - cardiac surgeries . \n mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , guja
+ rat , india ) . \n u.n . mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , gujarat , india ) . \n """
+
+ dct = tok(
+ [ARTICLE],
+ max_length=1024,
+ padding="max_length",
+ truncation=True,
+ return_tensors="np",
+ )
+
+ hypotheses_batch = model.generate(
+ **dct,
+ num_beams=4,
+ length_penalty=2.0,
+ max_length=142,
+ min_length=56,
+ do_sample=False,
+ early_stopping=True,
+ ).sequences
+
+ decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+ self.assertListEqual(
+ self.expected_summary(),
+ decoded,
+ )
diff --git a/tests/models/longt5/test_modeling_longt5.py b/tests/models/longt5/test_modeling_longt5.py
new file mode 100644
index 000000000000..61ad68921d9d
--- /dev/null
+++ b/tests/models/longt5/test_modeling_longt5.py
@@ -0,0 +1,1314 @@
+# coding=utf-8
+# Copyright 2022 Google LongT5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import copy
+import tempfile
+import unittest
+
+from transformers import LongT5Config, is_torch_available
+from transformers.models.auto import get_values
+from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
+from transformers.utils import cached_property
+
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+ AutoTokenizer,
+ LongT5EncoderModel,
+ LongT5ForConditionalGeneration,
+ LongT5Model,
+ )
+ from transformers.models.longt5.modeling_longt5 import LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+class LongT5ModelTester:
+ def __init__(
+ self,
+ parent,
+ vocab_size=99,
+ batch_size=13,
+ encoder_seq_length=7,
+ decoder_seq_length=9,
+ local_radius=5,
+ encoder_attention_type="local",
+ global_block_size=3,
+ # For common tests
+ is_training=True,
+ use_attention_mask=True,
+ use_labels=True,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ d_ff=37,
+ relative_attention_num_buckets=8,
+ dropout_rate=0.1,
+ initializer_factor=0.002,
+ eos_token_id=1,
+ pad_token_id=0,
+ decoder_start_token_id=0,
+ scope=None,
+ decoder_layers=None,
+ large_model_config_path="google/long-t5-local-large",
+ ):
+
+ self.parent = parent
+ self.batch_size = batch_size
+ self.encoder_seq_length = encoder_seq_length
+ self.decoder_seq_length = decoder_seq_length
+ self.local_radius = local_radius
+ self.block_len = local_radius + 1
+ self.encoder_attention_type = encoder_attention_type
+ self.global_block_size = global_block_size
+ # For common tests
+ self.seq_length = self.decoder_seq_length
+ self.is_training = is_training
+ self.use_attention_mask = use_attention_mask
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.d_ff = d_ff
+ self.relative_attention_num_buckets = relative_attention_num_buckets
+ self.dropout_rate = dropout_rate
+ self.initializer_factor = initializer_factor
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.decoder_start_token_id = decoder_start_token_id
+ self.scope = scope
+ self.decoder_layers = decoder_layers
+ self.large_model_config_path = large_model_config_path
+
+ def get_large_model_config(self):
+ return LongT5Config.from_pretrained(self.large_model_config_path)
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
+ decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
+
+ attention_mask = None
+ decoder_attention_mask = None
+ if self.use_attention_mask:
+ attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
+ decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
+
+ lm_labels = None
+ if self.use_labels:
+ lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
+
+ config = self.get_config()
+
+ return (
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ )
+
+ def get_pipeline_config(self):
+ return LongT5Config(
+ vocab_size=166, # longt5 forces 100 extra tokens
+ d_model=self.hidden_size,
+ d_ff=self.d_ff,
+ d_kv=self.hidden_size // self.num_attention_heads,
+ num_layers=self.num_hidden_layers,
+ num_decoder_layers=self.decoder_layers,
+ num_heads=self.num_attention_heads,
+ relative_attention_num_buckets=self.relative_attention_num_buckets,
+ dropout_rate=self.dropout_rate,
+ initializer_factor=self.initializer_factor,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.pad_token_id,
+ pad_token_id=self.pad_token_id,
+ decoder_start_token_id=self.decoder_start_token_id,
+ local_radius=self.local_radius,
+ encoder_attention_type=self.encoder_attention_type,
+ global_block_size=self.global_block_size,
+ )
+
+ def get_config(self):
+ return LongT5Config(
+ vocab_size=self.vocab_size,
+ d_model=self.hidden_size,
+ d_ff=self.d_ff,
+ d_kv=self.hidden_size // self.num_attention_heads,
+ num_layers=self.num_hidden_layers,
+ num_decoder_layers=self.decoder_layers,
+ num_heads=self.num_attention_heads,
+ relative_attention_num_buckets=self.relative_attention_num_buckets,
+ dropout_rate=self.dropout_rate,
+ initializer_factor=self.initializer_factor,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.pad_token_id,
+ pad_token_id=self.pad_token_id,
+ decoder_start_token_id=self.decoder_start_token_id,
+ local_radius=self.local_radius,
+ encoder_attention_type=self.encoder_attention_type,
+ global_block_size=self.global_block_size,
+ )
+
+ def check_prepare_lm_labels_via_shift_left(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ model = LongT5Model(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # make sure that lm_labels are correctly padded from the right
+ lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id)
+
+ # add causal pad token mask
+ triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not()
+ lm_labels.masked_fill_(triangular_mask, self.pad_token_id)
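+ # each row i now keeps its first i + 1 labels and is padded with pad tokens to the right, so the
+ # shifted decoder inputs can be compared against the labels row by row below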
+ decoder_input_ids = model._shift_right(lm_labels)
+
+ for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)):
+ # first item
+ self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id)
+ if i < decoder_input_ids_slice.shape[-1]:
+ if i < decoder_input_ids.shape[-1] - 1:
+ # items before diagonal
+ self.parent.assertListEqual(
+ decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist()
+ )
+ # pad items after diagonal
+ if i < decoder_input_ids.shape[-1] - 2:
+ self.parent.assertListEqual(
+ decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist()
+ )
+ else:
+ # all items after square
+ self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())
+
+ def create_and_check_model(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ model = LongT5Model(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ )
+ result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+ decoder_output = result.last_hidden_state
+ decoder_past = result.past_key_values
+ encoder_output = result.encoder_last_hidden_state
+
+ self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
+ self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))
+ # There should be `num_layers` key value embeddings stored in decoder_past
+ self.parent.assertEqual(len(decoder_past), config.num_layers)
+ # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
+ self.parent.assertEqual(len(decoder_past[0]), 4)
+
+ def create_and_check_with_lm_head(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ model = LongT5ForConditionalGeneration(config=config).to(torch_device).eval()
+ outputs = model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ labels=lm_labels,
+ )
+ self.parent.assertEqual(len(outputs), 4)
+ self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
+ self.parent.assertEqual(outputs["loss"].size(), ())
+
+ def create_and_check_decoder_model_past(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ model = LongT5Model(config=config).get_decoder().to(torch_device).eval()
+ # first forward pass
+ outputs = model(input_ids, use_cache=True)
+ outputs_use_cache_conf = model(input_ids)
+ outputs_no_past = model(input_ids, use_cache=False)
+
+ self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+ self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+ output, past_key_values = outputs.to_tuple()
+
+ # create hypothetical next token and extend to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # append next token to input_ids
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+
+ output_from_no_past = model(next_input_ids)["last_hidden_state"]
+ output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
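+ # the cached forward pass only returns the hidden state of the newly fed token, so its position 0
+ # has to match the last position of the full (no-past) forward pass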
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_decoder_model_attention_mask_past(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ model = LongT5Model(config=config).get_decoder()
+ model.to(torch_device)
+ model.eval()
+
+ # create attention mask
+ attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+
+ half_seq_length = input_ids.shape[-1] // 2
+ attn_mask[:, half_seq_length:] = 0
+
+ # first forward pass
+ output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()
+
+ # create hypothetical next token and extend to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # change a random masked slice from input_ids
+ random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
+ random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
+ input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
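+ # the modified position always falls in the masked-out second half of the sequence, so it must not
+ # influence the outputs and the cached and non-cached forward passes should still agree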
+
+ # append to next input_ids and attn_mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ attn_mask = torch.cat(
+ [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
+ dim=1,
+ )
+
+ # get two different outputs
+ output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
+ "last_hidden_state"
+ ]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_decoder_model_past_large_inputs(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ model = LongT5Model(config=config).get_decoder().to(torch_device).eval()
+ # first forward pass
+ outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
+
+ output, past_key_values = outputs.to_tuple()
+
+ # create hypothetical multiple next tokens and extend to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
+
+ # append next tokens to input_ids and extend the attention mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1)
+
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
+ "last_hidden_state"
+ ]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_generate_with_past_key_values(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ model = LongT5ForConditionalGeneration(config=config).to(torch_device).eval()
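+ # with identical seeds the sampling is deterministic, so any difference between the two
+ # generations below could only come from the use of past_key_values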
+ torch.manual_seed(0)
+ output_without_past_cache = model.generate(
+ input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False
+ )
+ torch.manual_seed(0)
+ output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
+ self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
+
+ def create_and_check_encoder_decoder_shared_weights(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ for model_class in [LongT5Model, LongT5ForConditionalGeneration]:
+ torch.manual_seed(0)
+ model = model_class(config=config).to(torch_device).eval()
+ # load state dict copies weights but does not tie them
+ model.encoder.load_state_dict(model.decoder.state_dict(), strict=False)
+
+ torch.manual_seed(0)
+ tied_config = copy.deepcopy(config)
+ tied_config.tie_encoder_decoder = True
+ tied_model = model_class(config=tied_config).to(torch_device).eval()
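+ # with tie_encoder_decoder=True the encoder and decoder share one set of weights, so the tied
+ # model has fewer parameters but (given the same seed and the state dict copied above) should
+ # produce the same outputs as the untied model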
+
+ model_result = model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ )
+
+ tied_model_result = tied_model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ )
+
+ # check that the tied model has fewer parameters
+ self.parent.assertLess(
+ sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
+ )
+ random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
+
+ # check that outputs are equal
+ self.parent.assertTrue(
+ torch.allclose(
+ model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
+ )
+ )
+
+ # check that outputs after saving and loading are equal
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ tied_model.save_pretrained(tmpdirname)
+ tied_model = model_class.from_pretrained(tmpdirname)
+ tied_model.to(torch_device)
+ tied_model.eval()
+
+ # check that the tied model has fewer parameters
+ self.parent.assertLess(
+ sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
+ )
+ random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
+
+ tied_model_result = tied_model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ )
+
+ # check that outputs are equal
+ self.parent.assertTrue(
+ torch.allclose(
+ model_result[0][0, :, random_slice_idx],
+ tied_model_result[0][0, :, random_slice_idx],
+ atol=1e-4,
+ )
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ) = config_and_inputs
+
+ inputs_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": decoder_attention_mask,
+ "use_cache": False,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+
+ all_model_classes = (LongT5Model, LongT5ForConditionalGeneration) if is_torch_available() else ()
+ all_generative_model_classes = (LongT5ForConditionalGeneration,) if is_torch_available() else ()
+ fx_compatible = False
+ test_pruning = False
+ test_torchscript = True
+ test_resize_embeddings = True
+ test_model_parallel = False
+ is_encoder_decoder = True
+
+ def setUp(self):
+ self.model_tester = LongT5ModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_shift_right(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_with_lm_head(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_with_lm_head(*config_and_inputs)
+
+ def test_decoder_model_past(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
+
+ def test_decoder_model_past_with_attn_mask(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
+
+ def test_decoder_model_past_with_3d_attn_mask(self):
+ (
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ) = self.model_tester.prepare_config_and_inputs()
+
+ attention_mask = ids_tensor(
+ [self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length],
+ vocab_size=2,
+ )
+ decoder_attention_mask = ids_tensor(
+ [self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length],
+ vocab_size=2,
+ )
+
+ self.model_tester.create_and_check_decoder_model_attention_mask_past(
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ )
+
+ def test_decoder_model_past_with_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
+ def test_generate_with_past_key_values(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs)
+
+ def test_encoder_decoder_shared_weights(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = LongT5Model.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+ @slow
+ def test_export_to_onnx(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ model = LongT5Model(config_and_inputs[0]).to(torch_device)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ torch.onnx.export(
+ model,
+ (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
+ f"{tmpdirname}/longt5_test.onnx",
+ export_params=True,
+ opset_version=13,
+ input_names=["input_ids", "decoder_input_ids"],
+ )
+
+ def test_generate_with_head_masking(self):
+ attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ config = config_and_inputs[0]
+ max_length = config_and_inputs[1].shape[-1] + 3
+ model = LongT5ForConditionalGeneration(config).eval()
+ model.to(torch_device)
+
+ head_masking = {
+ "head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device),
+ "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
+ "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
+ }
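+ # each mask zeroes out every attention head of the corresponding module, so the attention
+ # weights returned by generate are expected to sum to exactly 0.0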
+
+ for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
+ head_masks = {name: mask}
+ # Explicitly pass decoder_head_mask as it is required by the LongT5 model when head_mask is specified
+ if name == "head_mask":
+ head_masks["decoder_head_mask"] = torch.ones(
+ config.num_decoder_layers, config.num_heads, device=torch_device
+ )
+
+ out = model.generate(
+ config_and_inputs[1],
+ num_beams=1,
+ max_length=max_length,
+ output_attentions=True,
+ return_dict_in_generate=True,
+ **head_masks,
+ )
+ # We check the state of decoder_attentions and cross_attentions just from the last step
+ attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
+ self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
+
+ def test_attention_outputs(self):
+ if not self.has_attentions:
+ pass
+
+ else:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ seq_len = getattr(self.model_tester, "seq_length", None)
+ decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
+ decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+ chunk_length = getattr(self.model_tester, "chunk_length", None)
+ block_len = getattr(self.model_tester, "block_len", None)
+
+ if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
+ encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len],
+ )
+ out_len = len(outputs)
+
+ if self.is_encoder_decoder:
+ correct_outlen = 5
+
+ # loss is at first position
+ if "labels" in inputs_dict:
+ correct_outlen += 1 # loss is added to beginning
+ # Question Answering model returns start_logits and end_logits
+ if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
+ correct_outlen += 1 # start_logits and end_logits instead of only 1 output
+ if "past_key_values" in outputs:
+ correct_outlen += 1 # past_key_values have been returned
+
+ self.assertEqual(out_len, correct_outlen)
+
+ # decoder attentions
+ decoder_attentions = outputs.decoder_attentions
+ self.assertIsInstance(decoder_attentions, (list, tuple))
+ self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(decoder_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
+ )
+
+ # cross attentions
+ cross_attentions = outputs.cross_attentions
+ self.assertIsInstance(cross_attentions, (list, tuple))
+ self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(cross_attentions[0].shape[-3:]),
+ [
+ self.model_tester.num_attention_heads,
+ decoder_seq_length,
+ encoder_key_length,
+ ],
+ )
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ elif self.is_encoder_decoder:
+ added_hidden_states = 2
+ else:
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len],
+ )
+
+ def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
+ block_len = getattr(self.model_tester, "block_len", None)
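+ # With local attention, each block of block_len tokens attends to itself and its two neighbouring blocks,
+ # hence 3 * block_len key positions per query block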
+ encoder_expected_shape = (batch_size, 1, config.num_attention_heads, block_len, 3 * block_len)
+ self.assertIsInstance(attentions, tuple)
+ self.assertListEqual(
+ [layer_attentions.shape for layer_attentions in attentions],
+ [encoder_expected_shape] * len(attentions),
+ )
+
+
+@require_torch
+class LongT5TGlobalModelTest(LongT5ModelTest):
+ def setUp(self):
+ self.model_tester = LongT5ModelTester(
+ self, encoder_attention_type="transient-global", large_model_config_path="google/long-t5-tglobal-large"
+ )
+ self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)
+
+ def test_attention_outputs(self):
+ if not self.has_attentions:
+ pass
+
+ else:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ seq_len = getattr(self.model_tester, "seq_length", None)
+ decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
+ decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+ chunk_length = getattr(self.model_tester, "chunk_length", None)
+ block_len = getattr(self.model_tester, "block_len", None)
+ global_block_size = getattr(self.model_tester, "global_block_size", None)
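+ # Transient-global attention additionally attends to one global token per block of global_block_size tokens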
+ global_seq_len = encoder_seq_length // global_block_size
+
+ if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
+ encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
+ )
+ out_len = len(outputs)
+
+ if self.is_encoder_decoder:
+ correct_outlen = 5
+
+ # loss is at first position
+ if "labels" in inputs_dict:
+ correct_outlen += 1 # loss is added to beginning
+ # Question Answering model returns start_logits and end_logits
+ if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
+ correct_outlen += 1 # start_logits and end_logits instead of only 1 output
+ if "past_key_values" in outputs:
+ correct_outlen += 1 # past_key_values have been returned
+
+ self.assertEqual(out_len, correct_outlen)
+
+ # decoder attentions
+ decoder_attentions = outputs.decoder_attentions
+ self.assertIsInstance(decoder_attentions, (list, tuple))
+ self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(decoder_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
+ )
+
+ # cross attentions
+ cross_attentions = outputs.cross_attentions
+ self.assertIsInstance(cross_attentions, (list, tuple))
+ self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(cross_attentions[0].shape[-3:]),
+ [
+ self.model_tester.num_attention_heads,
+ decoder_seq_length,
+ encoder_key_length,
+ ],
+ )
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ elif self.is_encoder_decoder:
+ added_hidden_states = 2
+ else:
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
+ )
+
+ def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
+ block_len = getattr(self.model_tester, "block_len", None)
+ global_block_size = getattr(self.model_tester, "global_block_size", None)
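+ # The transient-global variant adds seq_length // global_block_size global key positions on top of the three local blocks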
+ global_seq_length = seq_length // global_block_size
+ encoder_expected_shape = (
+ batch_size,
+ 1,
+ config.num_attention_heads,
+ block_len,
+ 3 * block_len + global_seq_length,
+ )
+ self.assertIsInstance(attentions, tuple)
+ self.assertListEqual(
+ [layer_attentions.shape for layer_attentions in attentions],
+ [encoder_expected_shape] * len(attentions),
+ )
+
+
+class LongT5EncoderOnlyModelTester:
+ def __init__(
+ self,
+ parent,
+ vocab_size=99,
+ batch_size=13,
+ encoder_seq_length=7,
+ local_radius=5,
+ encoder_attention_type="local",
+ global_block_size=3,
+ # For common tests
+ use_attention_mask=True,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ d_ff=37,
+ relative_attention_num_buckets=8,
+ is_training=False,
+ dropout_rate=0.1,
+ initializer_factor=0.002,
+ is_encoder_decoder=False,
+ eos_token_id=1,
+ pad_token_id=0,
+ scope=None,
+ large_model_config_path="google/long-t5-local-large",
+ ):
+
+ self.parent = parent
+ self.batch_size = batch_size
+ self.encoder_seq_length = encoder_seq_length
+ self.local_radius = local_radius
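+ # LongT5's local attention uses blocks of local_radius + 1 tokens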
+ self.block_len = local_radius + 1
+ self.encoder_attention_type = encoder_attention_type
+ self.global_block_size = global_block_size
+ # For common tests
+ self.seq_length = self.encoder_seq_length
+ self.use_attention_mask = use_attention_mask
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.d_ff = d_ff
+ self.relative_attention_num_buckets = relative_attention_num_buckets
+ self.dropout_rate = dropout_rate
+ self.initializer_factor = initializer_factor
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.is_encoder_decoder = is_encoder_decoder
+ self.scope = None
+ self.is_training = is_training
+ self.large_model_config_path = large_model_config_path
+
+ def get_large_model_config(self):
+ return LongT5Config.from_pretrained(self.large_model_config_path)
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
+
+ attention_mask = None
+ if self.use_attention_mask:
+ attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
+
+ config = LongT5Config(
+ vocab_size=self.vocab_size,
+ d_model=self.hidden_size,
+ d_ff=self.d_ff,
+ d_kv=self.hidden_size // self.num_attention_heads,
+ num_layers=self.num_hidden_layers,
+ num_heads=self.num_attention_heads,
+ relative_attention_num_buckets=self.relative_attention_num_buckets,
+ dropout_rate=self.dropout_rate,
+ initializer_factor=self.initializer_factor,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.pad_token_id,
+ pad_token_id=self.pad_token_id,
+ is_encoder_decoder=self.is_encoder_decoder,
+ local_radius=self.local_radius,
+ encoder_attention_type=self.encoder_attention_type,
+ global_block_size=self.global_block_size,
+ )
+
+ return (
+ config,
+ input_ids,
+ attention_mask,
+ )
+
+ def create_and_check_model(
+ self,
+ config,
+ input_ids,
+ attention_mask,
+ ):
+ model = LongT5EncoderModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )
+ result = model(input_ids=input_ids)
+ encoder_output = result.last_hidden_state
+
+ self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ attention_mask,
+ ) = config_and_inputs
+
+ inputs_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+class LongT5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (LongT5EncoderModel,) if is_torch_available() else ()
+ test_pruning = False
+ test_torchscript = True
+ test_resize_embeddings = False
+ test_model_parallel = False
+
+ def setUp(self):
+ self.model_tester = LongT5EncoderOnlyModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_attention_outputs(self):
+ if not self.has_attentions:
+ pass
+
+ else:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ block_len = getattr(self.model_tester, "block_len", 4)
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ elif self.is_encoder_decoder:
+ added_hidden_states = 2
+ else:
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len],
+ )
+
+
+class LongT5EncoderOnlyTGlobalModelTest(LongT5EncoderOnlyModelTest):
+ def setUp(self):
+ self.model_tester = LongT5EncoderOnlyModelTester(
+ self, encoder_attention_type="transient-global", large_model_config_path="google/long-t5-tglobal-large"
+ )
+ self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)
+
+ def test_attention_outputs(self):
+ if not self.has_attentions:
+ pass
+
+ else:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ block_len = getattr(self.model_tester, "block_len", None)
+ seq_len = getattr(self.model_tester, "seq_length", None)
+ global_block_size = getattr(self.model_tester, "global_block_size", 4)
+ global_seq_len = seq_len // global_block_size
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ elif self.is_encoder_decoder:
+ added_hidden_states = 2
+ else:
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
+ )
+
+
+def use_task_specific_params(model, task):
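+ """Update the model config in place with the parameters registered under `task_specific_params[task]`."""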
+ model.config.update(model.config.task_specific_params[task])
+
+
+@require_torch
+@require_sentencepiece
+@require_tokenizers
+class LongT5ModelIntegrationTests(unittest.TestCase):
+ @cached_property
+ def model(self):
+ return LongT5ForConditionalGeneration.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps").to(
+ torch_device
+ )
+
+ @cached_property
+ def tokenizer(self):
+ return AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
+
+ def expected_summary(self):
+ return [
+ "background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in"
+ " developing world . it provides an excellent resolution for visualization of the coronaryarteries for"
+ " catheter - based or operating interventions . although the association of this technique with major"
+ " complications such as mortality is highly uncommon , it is frequently associated with various cardiac"
+ " and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the"
+ " diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for"
+ " major noncoron ary cardiac surgery referred"
+ ]
+
+ @slow
+ def test_summarization(self):
+ model = self.model
+ tok = self.tokenizer
+
+ ARTICLE = """coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . \n it provides an excellent resolution for visualization of the coronary arteries for catheter - based or operating interventions . \n
+ although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications . computed tomography ( ct ) coronary angiography is
+ a promising technique for the evaluation of cad noninvasively . \n it assesses disease within the coronary artery and provides qualitative and quantitative information about nonobstructive atherosclerotic plaque burden within the vessel
+ wall . \n thus , ct angiography - based disease evaluation may provide clinically more significant information than conventional angiography . the introduction of multi - slice computed tomography ( msct ) technology such as 64-slice , 12
+ 8-slice , 256-slice , and now 320-slice msct has produced a high diagnostic accuracy of ct coronary angiography . \n it has consistently showed to have a very high negative predictive value ( well above 90% ) in ruling out patients with s
+ ignificant cad defined as coronary luminal stenosis of > 50% . \n the american college of cardiology / american heart association recommends that coronary angiography should be performed before valve surgery in men aged > 40 years , women
+ aged > 35 years with coronary risk factors and in postmenopausal women . \n the prevalence of cad in patients undergoing valve replacement is 2040% in developed countries . in the previous studies , \n the incidence of angiographically p
+ roven cad in acquired valvular diseases has been shown to vary widely from 9% to 41% . in aortic stenosis , \n we aimed to report the diagnostic performance of 128-slice ct coronary angiography in 50 patients undergoing for major noncoron
+ ary cardiac surgery referred for diagnostic invasive coronary angiography to assess the extent and severity of coronary stenosis . \n during january 2013 to december 2014 , we enrolled fifty major noncoronary cardiac surgery patients sche
+ duled for invasive coronary angiography who fulfilled the following inclusion criteria of age 40 years , having low or intermediate probability of cad , left ventricular ejection fraction ( lvef ) > 35% , and patient giving informed conse
+ nt for undergoing msct and conventional coronary angiography . \n those having any contraindication for contrast injection , lvef < 35% , high pretest probability of cad , and hemodynamic instability were excluded from the study . \n pati
+ ents with heart rates of > 70 bpm received ( unless they had known overt heart failure or electrocardiogram ( ecg ) atrioventricular conduction abnormalities ) a single oral dose of 100 mg metoprolol 45 min before the scan . \n patients w
+ ith heart rates of > 80 bpm received an additional oral dose of metoprolol if not contraindicated . \n all patients were scanned with a 128-slice ct scanner ( siemens , somatom definition as ) equipped with a new feature in msct technolog
+ y , so - called z - axis flying - focus technology . \n the central 32 detector rows acquire 0.6-mm slices , and the flying - focus spot switches back and forth between 2 z positions between each reading . \n two slices per detector row a
+ re acquired , which results in a higher oversampling rate in the z - axis , thereby reducing artifacts related to the spiral acquisition and improving spatial resolution down to 0.4 mm . \n a bolus of 6580 ml contrast material ( omnipaque
+ ) was injected through an arm vein at a flow rate of 5 ml / s . \n a bolus tracking technique was used to synchronize the arrival of contrast in the coronary arteries with the initiation of the scan . to monitor the arrival of contrast m
+ aterial , \n axial scans were obtained at the level of the ascending aorta with a delay of 10 s after the start of the contrast injection . \n the scan was automatically started when a threshold of 150 hounsfield units was reached in a re
+ gion of interest positioned in the ascending aorta . \n images were reconstructed with ecg gating to obtain optimal , motion - free image quality . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a s
+ ingle observer unaware of the multi - slice ct results identified coronary lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiograp
+ hy . \n lesions were classified as having nonsignificant disease ( luminal irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean
+ lumen diameter reduction was 50% using a validated quantitative coronary angiography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiograp
+ hy . \n total calcium scores of all patients were calculated with dedicated software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of th
+ e number , areas , and peak hounsfield units of the detected calcified lesions . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were
+ used to identify coronary lesions and ( curved ) multiplanar reconstructions to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the di
+ agnostic performance of ct coronary angiography for the detection of significant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and
+ positive and negative likelihood ratios with the corresponding exact 95% of confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease p
+ er vessel ) , and patient by patient ( no or any disease per patient ) . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a single observer unaware of the multi - slice ct results identified coronary
+ lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiography . \n lesions were classified as having nonsignificant disease ( luminal
+ irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean lumen diameter reduction was 50% using a validated quantitative coronary an
+ giography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiography . \n total calcium scores of all patients were calculated with dedicated
+ software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of the number , areas , and peak hounsfield units of the detected calcified lesi
+ ons . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were used to identify coronary lesions and ( curved ) multiplanar reconstruction
+ s to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the diagnostic performance of ct coronary angiography for the detection of signif
+ icant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and positive and negative likelihood ratios with the corresponding exact 95% of
+ confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease per vessel ) , and patient by patient ( no or any disease per patient ) . \n
+ in this study , 29 ( 58% ) subjects were female , and 21 ( 42% ) were male showing an average age of 50.36 8.39 years . \n of fifty patients 24 ( 48% ) , 13 ( 26% ) , eight ( 16% ) , and five ( 10% ) underwent mitral valve replacement ,
+ double valve replacement ( dvr ) , aortic valve replacement , and other surgeries , respectively . \n high distribution of cad risk factors such as hypertension ( 24% ) , smoking ( 22% ) , and dyslipidemia ( 18% ) was observed in the stu
+ dy group . \n the mean creatinine level was 0.766 0.17 and average dye used in conventional angiography was 48.5 26.6 whereas for ct angiography it was 72.8 6.32 . \n average radiation dose in conventional coronary angiography and msct
+ coronary angiography was 5.2 msv and 9.2 msv , respectively . \n the majority of the patients had sinus rhythm ( 68% ) , whereas atrial fibrillation was found in 32% of the subjects . \n patients included in the study had low to intermed
+ iate probability of cad . in this study , three patients had complications after conventional angiography . \n complications were of local site hematoma , acute kidney injury managed conservatively , and acute heart failure . \n a patient
+ who developed hematoma was obese female patients with body mass index > 30 kg / m . \n the patient suffered from pseudoaneurysm , had hospitalized for 9 days , which leads to increased morbidity and cost of hospital stay . \n the diagnos
+ tic accuracy of ct coronary angiography was evaluated regarding true positive , true negative values and is presented in table 1 . the overall sensitivity and \n specificity of ct angiography technique was 100% ( 95% ci : 39.76%100% ) and
+ 91.30% ( 95% ci : 79.21%97.58% ) , respectively [ table 2 ] . \n the positive predictive value ( 50% ; 95% ci : 15.70%84.30% ) and negative predictive value ( 100% ; 95% ci : 91.59%100% ) of ct angiography were also fairly high in these
+ patients . \n recent reports from multiple studies demonstrated that recent - generation msct scanners showed promise for noninvasive detection of coronary stenosis however , until now no studies were found regarding the clinical efficacy
+ or prognostic value of 128-slice ct coronary angiography versus conventional invasive coronary angiography in the diagnosis of patients planned for major noncoronary surgeries such as dvr , bentall , atrial septal defect closure , etc .
+ in our study , we reported 8% cad prevalence in patients planned for major noncoronary cardiac surgery . \n we performed conventional and msct coronary angiography in all patients and the results showed that ct coronary angiography with i
+ nvasive coronary angiography as the reference standard had a considerably high sensitivity ( 100% ) and specificity ( 95.65% ) . \n the health economic model using invasive coronary angiography as the reference standard showed that at a p
+ retest probability of cad of 70% or lower , ct coronary angiography resulted in lower cost per patient with a true positive diagnosis . at a pretest probability of cad of 70% or higher , invasive coronary angiography was associated with a
+ lower cost per patient with a true positive diagnosis . in our study population , \n two patients developed local site complications in the form of hematoma and pseudoaneurysm after conventional angiography . \n hence , msct coronary ang
+ iography will be more favorable in female obese patients with intermediate likelihood of cad . \n hence , msct coronary angiography will be cost - effective in patients of valvular heart diseases . \n however , ct angiography suffers from
+ a drawback that average amount of dye used in msct coronary angiography were 72.8 6.32 ml which is higher than average amount of dye required for conventional angiography ( 48.6 26.6 ml ) . \n hence , the use of ct coronary angiography
+ could not be used in patients with known renal dysfunction , where reduction of contrast dye load is highly advocated . \n our results show that 128-slice ct coronary angiography is a reliable technique to detect coronary stenosis in pat
+ ients planned for noncoronary cardiac surgery . \n although there has been important technological progress in the development of ct coronary angiography , its clinical application remains limited . \n a study wth large numbers of patient
+ s is required for the recommendation of only ct coronary angiography for the coronary evaluation in major non - cardiac surgeries . \n mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , guja
+ rat , india ) . \n u.n . mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , gujarat , india ) . \n """
+
+ dct = tok(
+ [ARTICLE],
+ max_length=1024,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ ).to(torch_device)
+
+ hypotheses_batch = model.generate(
+ **dct,
+ num_beams=4,
+ length_penalty=2.0,
+ max_length=142,
+ min_length=56,
+ no_repeat_ngram_size=3,
+ do_sample=False,
+ early_stopping=True,
+ )
+
+ decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+ self.assertListEqual(
+ self.expected_summary(),
+ decoded,
+ )
+
+ @slow
+ def test_inference_hidden_states(self):
+ model = self.model
+
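+ # Toy sequences of length 25 padded with zeros; the attention mask below zeroes out the padding positions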
+ input_ids = torch.tensor(
+ [[100, 19, 3, 9, 7142, 1200, 145, 8, 1252, 14145, 2034, 812, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+ dtype=torch.long,
+ device=torch_device,
+ )
+ decoder_input_ids = torch.tensor(
+ [[100, 19, 3, 9, 7142, 1200, 145, 8, 1252, 14145, 2034, 812, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+ dtype=torch.long,
+ device=torch_device,
+ )
+ attention_mask = torch.tensor(
+ [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+ dtype=torch.long,
+ device=torch_device,
+ )
+
+ output = model(
+ input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, output_hidden_states=True
+ )
+
+ # check if encoder_outputs match
+ expected_output_slice = torch.tensor([0.0629, -0.1294, -0.0089, 0.0772, 0.0663], device=torch_device)
+ self.assertTrue(torch.allclose(output.encoder_hidden_states[-1][0, 0, :5], expected_output_slice, atol=1e-4))
+
+ # check if logits match
+ expected_output_slice = torch.tensor([5.5231, 6.1058, 3.1766, 8.2391, -5.9453], device=torch_device)
+ self.assertTrue(torch.allclose(output.logits[0, 0, :5], expected_output_slice, atol=1e-4))
diff --git a/tests/models/luke/test_modeling_luke.py b/tests/models/luke/test_modeling_luke.py
index 0661748da5a0..789988d5ca35 100644
--- a/tests/models/luke/test_modeling_luke.py
+++ b/tests/models/luke/test_modeling_luke.py
@@ -30,6 +30,10 @@
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
+ LukeForMultipleChoice,
+ LukeForQuestionAnswering,
+ LukeForSequenceClassification,
+ LukeForTokenClassification,
LukeModel,
LukeTokenizer,
)
@@ -66,6 +70,8 @@ def __init__(
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
+ num_labels=3,
+ num_choices=4,
num_entity_classification_labels=9,
num_entity_pair_classification_labels=6,
num_entity_span_classification_labels=4,
@@ -99,6 +105,8 @@ def __init__(
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.num_choices = num_choices
self.num_entity_classification_labels = num_entity_classification_labels
self.num_entity_pair_classification_labels = num_entity_pair_classification_labels
self.num_entity_span_classification_labels = num_entity_span_classification_labels
@@ -139,7 +147,8 @@ def prepare_config_and_inputs(self):
)
sequence_labels = None
- labels = None
+ token_labels = None
+ choice_labels = None
entity_labels = None
entity_classification_labels = None
entity_pair_classification_labels = None
@@ -147,7 +156,9 @@ def prepare_config_and_inputs(self):
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
- labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+ choice_labels = ids_tensor([self.batch_size], self.num_choices)
+
entity_labels = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size)
entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels)
@@ -170,7 +181,8 @@ def prepare_config_and_inputs(self):
entity_token_type_ids,
entity_position_ids,
sequence_labels,
- labels,
+ token_labels,
+ choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -207,7 +219,8 @@ def create_and_check_model(
entity_token_type_ids,
entity_position_ids,
sequence_labels,
- labels,
+ token_labels,
+ choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -247,7 +260,8 @@ def create_and_check_for_masked_lm(
entity_token_type_ids,
entity_position_ids,
sequence_labels,
- labels,
+ token_labels,
+ choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -266,13 +280,16 @@ def create_and_check_for_masked_lm(
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
- labels=labels,
+ labels=token_labels,
entity_labels=entity_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
- self.parent.assertEqual(
- result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size)
- )
+ if entity_ids is not None:
+ self.parent.assertEqual(
+ result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size)
+ )
+ else:
+ self.parent.assertIsNone(result.entity_logits)
def create_and_check_for_entity_classification(
self,
@@ -285,7 +302,8 @@ def create_and_check_for_entity_classification(
entity_token_type_ids,
entity_position_ids,
sequence_labels,
- labels,
+ token_labels,
+ choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -319,7 +337,8 @@ def create_and_check_for_entity_pair_classification(
entity_token_type_ids,
entity_position_ids,
sequence_labels,
- labels,
+ token_labels,
+ choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -353,7 +372,8 @@ def create_and_check_for_entity_span_classification(
entity_token_type_ids,
entity_position_ids,
sequence_labels,
- labels,
+ token_labels,
+ choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -383,6 +403,156 @@ def create_and_check_for_entity_span_classification(
result.logits.shape, (self.batch_size, self.entity_length, self.num_entity_span_classification_labels)
)
+ def create_and_check_for_question_answering(
+ self,
+ config,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ entity_ids,
+ entity_attention_mask,
+ entity_token_type_ids,
+ entity_position_ids,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ entity_labels,
+ entity_classification_labels,
+ entity_pair_classification_labels,
+ entity_span_classification_labels,
+ ):
+ model = LukeForQuestionAnswering(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ entity_ids=entity_ids,
+ entity_attention_mask=entity_attention_mask,
+ entity_token_type_ids=entity_token_type_ids,
+ entity_position_ids=entity_position_ids,
+ start_positions=sequence_labels,
+ end_positions=sequence_labels,
+ )
+ self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+ self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
+
+ def create_and_check_for_sequence_classification(
+ self,
+ config,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ entity_ids,
+ entity_attention_mask,
+ entity_token_type_ids,
+ entity_position_ids,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ entity_labels,
+ entity_classification_labels,
+ entity_pair_classification_labels,
+ entity_span_classification_labels,
+ ):
+ config.num_labels = self.num_labels
+ model = LukeForSequenceClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ entity_ids=entity_ids,
+ entity_attention_mask=entity_attention_mask,
+ entity_token_type_ids=entity_token_type_ids,
+ entity_position_ids=entity_position_ids,
+ labels=sequence_labels,
+ )
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def create_and_check_for_token_classification(
+ self,
+ config,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ entity_ids,
+ entity_attention_mask,
+ entity_token_type_ids,
+ entity_position_ids,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ entity_labels,
+ entity_classification_labels,
+ entity_pair_classification_labels,
+ entity_span_classification_labels,
+ ):
+ config.num_labels = self.num_labels
+ model = LukeForTokenClassification(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ entity_ids=entity_ids,
+ entity_attention_mask=entity_attention_mask,
+ entity_token_type_ids=entity_token_type_ids,
+ entity_position_ids=entity_position_ids,
+ labels=token_labels,
+ )
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
+
+ def create_and_check_for_multiple_choice(
+ self,
+ config,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ entity_ids,
+ entity_attention_mask,
+ entity_token_type_ids,
+ entity_position_ids,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ entity_labels,
+ entity_classification_labels,
+ entity_pair_classification_labels,
+ entity_span_classification_labels,
+ ):
+ config.num_choices = self.num_choices
+ model = LukeForMultipleChoice(config=config)
+ model.to(torch_device)
+ model.eval()
+ multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ multiple_choice_attention_mask = attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ multiple_choice_entity_ids = entity_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ multiple_choice_entity_token_type_ids = (
+ entity_token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ )
+ multiple_choice_entity_attention_mask = (
+ entity_attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ )
+ multiple_choice_entity_position_ids = (
+ entity_position_ids.unsqueeze(1).expand(-1, self.num_choices, -1, -1).contiguous()
+ )
+ result = model(
+ multiple_choice_inputs_ids,
+ attention_mask=multiple_choice_attention_mask,
+ token_type_ids=multiple_choice_token_type_ids,
+ entity_ids=multiple_choice_entity_ids,
+ entity_attention_mask=multiple_choice_entity_attention_mask,
+ entity_token_type_ids=multiple_choice_entity_token_type_ids,
+ entity_position_ids=multiple_choice_entity_position_ids,
+ labels=choice_labels,
+ )
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
+
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -395,7 +565,8 @@ def prepare_config_and_inputs_for_common(self):
entity_token_type_ids,
entity_position_ids,
sequence_labels,
- labels,
+ token_labels,
+ choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -423,6 +594,10 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
+ LukeForQuestionAnswering,
+ LukeForSequenceClassification,
+ LukeForTokenClassification,
+ LukeForMultipleChoice,
)
if is_torch_available()
else ()
@@ -433,7 +608,19 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
test_head_masking = True
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ entity_inputs_dict = {k: v for k, v in inputs_dict.items() if k.startswith("entity")}
+ inputs_dict = {k: v for k, v in inputs_dict.items() if not k.startswith("entity")}
+
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+ if model_class == LukeForMultipleChoice:
+ entity_inputs_dict = {
+ k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
+ if v.ndim == 2
+ else v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1, -1).contiguous()
+ for k, v in entity_inputs_dict.items()
+ }
+ inputs_dict.update(entity_inputs_dict)
+
if model_class == LukeForEntitySpanClassification:
inputs_dict["entity_start_positions"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.entity_length), dtype=torch.long, device=torch_device
@@ -443,7 +630,12 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
)
if return_labels:
- if model_class in (LukeForEntityClassification, LukeForEntityPairClassification):
+ if model_class in (
+ LukeForEntityClassification,
+ LukeForEntityPairClassification,
+ LukeForSequenceClassification,
+ LukeForMultipleChoice,
+ ):
inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
@@ -453,6 +645,12 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
dtype=torch.long,
device=torch_device,
)
+ elif model_class == LukeForTokenClassification:
+ inputs_dict["labels"] = torch.zeros(
+ (self.model_tester.batch_size, self.model_tester.seq_length),
+ dtype=torch.long,
+ device=torch_device,
+ )
elif model_class == LukeForMaskedLM:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length),
@@ -488,6 +686,27 @@ def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
+ def test_for_masked_lm_with_word_only(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ config_and_inputs = (*config_and_inputs[:4], *((None,) * len(config_and_inputs[4:])))
+ self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
+
+ def test_for_question_answering(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
+
+ def test_for_sequence_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+
+ def test_for_token_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
+
+ def test_for_multiple_choice(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
+
def test_for_entity_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
@@ -624,7 +843,10 @@ def test_inference_base_model(self):
model.to(torch_device)
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_classification")
- text = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ text = (
+ "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped"
+ " the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ )
span = (39, 42)
encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt")
@@ -656,7 +878,10 @@ def test_inference_large_model(self):
model.to(torch_device)
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large", task="entity_classification")
- text = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ text = (
+ "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped"
+ " the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ )
span = (39, 42)
encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt")
diff --git a/tests/models/luke/test_tokenization_luke.py b/tests/models/luke/test_tokenization_luke.py
index 81dce277a385..aa208f950bf3 100644
--- a/tests/models/luke/test_tokenization_luke.py
+++ b/tests/models/luke/test_tokenization_luke.py
@@ -480,7 +480,10 @@ def test_text_pair_padding_pytorch_tensors(self):
def test_entity_classification_no_padding_or_truncation(self):
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_classification")
- sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ sentence = (
+ "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped"
+ " the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ )
span = (39, 42)
encoding = tokenizer(sentence, entity_spans=[span], return_token_type_ids=True)
@@ -491,7 +494,8 @@ def test_entity_classification_no_padding_or_truncation(self):
self.assertEqual(len(encoding["token_type_ids"]), 42)
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
- "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon.",
+ "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous"
+ " netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon.",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][9:12], spaces_between_special_tokens=False), " she"
@@ -514,7 +518,10 @@ def test_entity_classification_padding_pytorch_tensors(self):
tokenizer = LukeTokenizer.from_pretrained(
"studio-ousia/luke-base", task="entity_classification", return_token_type_ids=True
)
- sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ sentence = (
+ "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped"
+ " the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ )
# entity information
span = (39, 42)
diff --git a/tests/models/lxmert/test_modeling_lxmert.py b/tests/models/lxmert/test_modeling_lxmert.py
index 7061aaa7d379..1c51d02e96b7 100644
--- a/tests/models/lxmert/test_modeling_lxmert.py
+++ b/tests/models/lxmert/test_modeling_lxmert.py
@@ -535,6 +535,7 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (LxmertModel, LxmertForPreTraining, LxmertForQuestionAnswering) if is_torch_available() else ()
+ fx_compatible = True
test_head_masking = False
test_pruning = False
test_torchscript = False
diff --git a/tests/models/lxmert/test_modeling_tf_lxmert.py b/tests/models/lxmert/test_modeling_tf_lxmert.py
index 7594f889189c..73eda47eb950 100644
--- a/tests/models/lxmert/test_modeling_tf_lxmert.py
+++ b/tests/models/lxmert/test_modeling_tf_lxmert.py
@@ -20,7 +20,7 @@
import numpy as np
from transformers import LxmertConfig, is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers.testing_utils import require_tf, slow, tooslow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
@@ -600,8 +600,8 @@ def test_model_common_attributes(self):
name = model.get_bias()
assert name is None
+ @tooslow
def test_saved_model_creation(self):
- # This test is too long (>30sec) and makes fail the CI
pass
@slow
diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py
index 003c5a57cf67..0d5bdc3ca303 100644
--- a/tests/models/m2m_100/test_modeling_m2m_100.py
+++ b/tests/models/m2m_100/test_modeling_m2m_100.py
@@ -231,6 +231,7 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
)
all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
@@ -354,7 +355,9 @@ def test_seq_to_seq_generation(self):
src_fr = [
"L'affaire NSA souligne l'absence totale de dƩbat sur le renseignement",
"Selon moi, il y a deux niveaux de rƩponse de la part du gouvernement franƧais.",
- "Lorsque François Hollande téléphone à Barack Obama ou quand le ministre des affaires étrangères Laurent Fabius convoque l'ambassadeur des Etats-Unis, ils réagissent à une vraie découverte, qui est celle de l'ampleur de la surveillance américaine sur l'ensemble des communications en France.",
+ "Lorsque François Hollande téléphone à Barack Obama ou quand le ministre des affaires étrangères Laurent"
+ " Fabius convoque l'ambassadeur des Etats-Unis, ils réagissent à une vraie découverte, qui est celle de"
+ " l'ampleur de la surveillance amƩricaine sur l'ensemble des communications en France.",
]
# The below article tests that we don't add any hypotheses outside of the top n_beams
@@ -370,7 +373,9 @@ def test_seq_to_seq_generation(self):
expected_en = [
"The NSA case highlights the total absence of intelligence debate",
"I think there are two levels of response from the French government.",
- "When FranƧois Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S. Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all communications in France.",
+ "When FranƧois Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S."
+ " Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all"
+ " communications in France.",
]
generated = tokenizer.batch_decode(
diff --git a/tests/models/m2m_100/test_tokenization_m2m_100.py b/tests/models/m2m_100/test_tokenization_m2m_100.py
index 729deb6cd486..ca8349d94016 100644
--- a/tests/models/m2m_100/test_tokenization_m2m_100.py
+++ b/tests/models/m2m_100/test_tokenization_m2m_100.py
@@ -187,9 +187,7 @@ def test_batch_fairseq_parity(self):
self.tokenizer.src_lang = "en"
self.tokenizer.tgt_lang = "fr"
- batch = self.tokenizer(self.src_text, padding=True, return_tensors="pt")
- with self.tokenizer.as_target_tokenizer():
- batch["labels"] = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt").input_ids
+ batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
batch["decoder_input_ids"] = shift_tokens_right(
batch["labels"], self.tokenizer.pad_token_id, self.tokenizer.eos_token_id
@@ -217,17 +215,19 @@ def test_src_lang_setter(self):
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
@require_torch
- def test_as_target_tokenizer(self):
+ def test_tokenizer_target_mode(self):
self.tokenizer.tgt_lang = "mr"
- with self.tokenizer.as_target_tokenizer():
- self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("mr")])
- self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
+ self.tokenizer._switch_to_target_mode()
+ self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("mr")])
+ self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
+ self.tokenizer._switch_to_input_mode()
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
self.tokenizer.tgt_lang = "zh"
- with self.tokenizer.as_target_tokenizer():
- self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
- self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
+ self.tokenizer._switch_to_target_mode()
+ self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
+ self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
+ self.tokenizer._switch_to_input_mode()
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
@require_torch
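For reference, this hunk and several below replace the `as_target_tokenizer()` context manager with the tokenizer's `text_target` argument. A minimal sketch of the two call styles, for illustration only (the checkpoint and strings mirror the M2M100 tests above and are not part of the patch):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/m2m100_418M")
    tokenizer.src_lang, tokenizer.tgt_lang = "en", "fr"
    src_text = ["The NSA case highlights the total absence of intelligence debate"]
    tgt_text = ["L'affaire NSA souligne l'absence totale de débat sur le renseignement"]

    # old style: temporarily switch the tokenizer into target mode
    batch = tokenizer(src_text, padding=True, return_tensors="pt")
    with tokenizer.as_target_tokenizer():
        batch["labels"] = tokenizer(tgt_text, padding=True, return_tensors="pt").input_ids

    # new style: pass the targets directly and get "labels" back in the same batch
    batch = tokenizer(src_text, text_target=tgt_text, padding=True, return_tensors="pt")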
diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py
index 9c119f936af4..6ca951e37aed 100644
--- a/tests/models/marian/test_modeling_marian.py
+++ b/tests/models/marian/test_modeling_marian.py
@@ -123,6 +123,12 @@ def __init__(
self.bos_token_id = bos_token_id
self.decoder_start_token_id = decoder_start_token_id
+ # forcing a certain token to be generated sets all other tokens to -inf;
+ # if the token to be forced is itself already at -inf, this can lead to
+ # `nan` values and thus break generation
+ self.forced_bos_token_id = None
+ self.forced_eos_token_id = None
+
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
3,
@@ -152,6 +158,8 @@ def get_config(self):
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
+ forced_bos_token_id=self.forced_bos_token_id,
+ forced_eos_token_id=self.forced_eos_token_id,
)
def prepare_config_and_inputs_for_common(self):
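A small illustration of the failure mode described in the new comment (illustration only, not part of the patch): if the token a logits processor tries to force already has a logit of -inf, the whole row ends up at -inf, and softmax over an all -inf row produces `nan`:

    import torch

    # logits where the token we want to force (index 0) is already masked out
    logits = torch.tensor([[-float("inf"), 1.0, 2.0]])

    # a forced-token processor keeps only index 0 and sets everything else to -inf
    forced = torch.full_like(logits, -float("inf"))
    forced[:, 0] = logits[:, 0]  # still -inf

    print(torch.softmax(forced, dim=-1))  # tensor([[nan, nan, nan]])

Setting `forced_bos_token_id` and `forced_eos_token_id` to None in the tester sidesteps this during the randomly initialized generation tests.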
@@ -230,6 +238,7 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
all_model_classes = (MarianModel, MarianMTModel) if is_torch_available() else ()
all_generative_model_classes = (MarianMTModel,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
@@ -429,10 +438,7 @@ def test_forward(self):
src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
expected_ids = [38, 121, 14, 697, 38848, 0]
- model_inputs = self.tokenizer(src, return_tensors="pt").to(torch_device)
- with self.tokenizer.as_target_tokenizer():
- targets = self.tokenizer(tgt, return_tensors="pt")
- model_inputs["labels"] = targets["input_ids"].to(torch_device)
+ model_inputs = self.tokenizer(src, text_target=tgt, return_tensors="pt").to(torch_device)
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
diff --git a/tests/models/marian/test_modeling_tf_marian.py b/tests/models/marian/test_modeling_tf_marian.py
index e62d7f0d35cc..e8d65e0ad0ea 100644
--- a/tests/models/marian/test_modeling_tf_marian.py
+++ b/tests/models/marian/test_modeling_tf_marian.py
@@ -19,7 +19,7 @@
import warnings
from transformers import AutoTokenizer, MarianConfig, MarianTokenizer, TranslationPipeline, is_tf_available
-from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
+from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -246,8 +246,8 @@ def test_model_common_attributes(self):
name = model.get_bias()
assert name is None
+ @tooslow
def test_saved_model_creation(self):
- # This test is too long (>30sec) and makes fail the CI
pass
def test_resize_token_embeddings(self):
diff --git a/tests/models/marian/test_tokenization_marian.py b/tests/models/marian/test_tokenization_marian.py
index 2cbc0b0a3fe7..6a079036bb6d 100644
--- a/tests/models/marian/test_tokenization_marian.py
+++ b/tests/models/marian/test_tokenization_marian.py
@@ -145,9 +145,8 @@ def test_tokenizer_integration_seperate_vocabs(self):
src_ids = tokenizer(source_text).input_ids
self.assertListEqual(src_ids, expected_src_ids)
- with tokenizer.as_target_tokenizer():
- target_ids = tokenizer(target_text).input_ids
- self.assertListEqual(target_ids, expected_target_ids)
+ target_ids = tokenizer(text_target=target_text).input_ids
+ self.assertListEqual(target_ids, expected_target_ids)
decoded = tokenizer.decode(target_ids, skip_special_tokens=True)
self.assertEqual(decoded, target_text)
diff --git a/tests/models/maskformer/test_modeling_maskformer.py b/tests/models/maskformer/test_modeling_maskformer.py
index bbc24719d753..b1e61210612f 100644
--- a/tests/models/maskformer/test_modeling_maskformer.py
+++ b/tests/models/maskformer/test_modeling_maskformer.py
@@ -21,7 +21,7 @@
from tests.test_modeling_common import floats_tensor
from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -212,6 +212,13 @@ def test_generate_without_input_ids(self):
def test_resize_tokens_embeddings(self):
pass
+ @require_torch_multi_gpu
+ @unittest.skip(
+ reason="MaskFormer has some layers using `add_module` which doesn't work well with `nn.DataParallel`"
+ )
+ def test_multi_gpu_data_parallel_forward(self):
+ pass
+
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -380,9 +387,12 @@ def test_inference_instance_segmentation_head(self):
self.assertEqual(
masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4)
)
- expected_slice = torch.tensor(
- [[-1.3738, -1.7725, -1.9365], [-1.5978, -1.9869, -2.1524], [-1.5796, -1.9271, -2.0940]]
- ).to(torch_device)
+ expected_slice = [
+ [-1.3737124, -1.7724937, -1.9364233],
+ [-1.5977281, -1.9867939, -2.1523695],
+ [-1.5795398, -1.9269832, -2.093942],
+ ]
+ expected_slice = torch.tensor(expected_slice).to(torch_device)
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
# class_queries_logits
class_queries_logits = outputs.class_queries_logits
diff --git a/tests/models/mbart/test_modeling_mbart.py b/tests/models/mbart/test_modeling_mbart.py
index 3ac2c542da1e..11f8bd7a0d51 100644
--- a/tests/models/mbart/test_modeling_mbart.py
+++ b/tests/models/mbart/test_modeling_mbart.py
@@ -113,6 +113,12 @@ def __init__(
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
+ # forcing a certain token to be generated sets all other tokens to -inf;
+ # if the token to be forced is itself already at -inf, this can lead to
+ # `nan` values and thus break generation
+ self.forced_bos_token_id = None
+ self.forced_eos_token_id = None
+
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
@@ -142,6 +148,8 @@ def get_config(self):
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
+ forced_bos_token_id=self.forced_bos_token_id,
+ forced_eos_token_id=self.forced_eos_token_id,
)
def prepare_config_and_inputs_for_common(self):
@@ -224,6 +232,7 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
)
all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
@@ -348,7 +357,9 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
]
tgt_text = [
"Åeful ONU declarÄ cÄ nu existÄ o soluÅ£ie militarÄ Ć®n Siria",
- 'Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu vor face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£a Åi mizeria pentru milioane de oameni.',
+ "Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei"
+ ' pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu vor'
+ " face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£a Åi mizeria pentru milioane de oameni.",
]
expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, 250004]
diff --git a/tests/models/mbart/test_modeling_tf_mbart.py b/tests/models/mbart/test_modeling_tf_mbart.py
index 559a44e5db6a..b1bdb40cf79f 100644
--- a/tests/models/mbart/test_modeling_tf_mbart.py
+++ b/tests/models/mbart/test_modeling_tf_mbart.py
@@ -17,7 +17,7 @@
import unittest
from transformers import AutoTokenizer, MBartConfig, is_tf_available
-from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
+from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -281,8 +281,8 @@ def _get_word_embedding_weight(model, embedding_layer):
models_equal = False
self.assertTrue(models_equal)
+ @tooslow
def test_saved_model_creation(self):
- # This test is too long (>30sec) and makes fail the CI
pass
diff --git a/tests/models/mbart/test_tokenization_mbart.py b/tests/models/mbart/test_tokenization_mbart.py
index d24aefb01fd9..f65662dbe247 100644
--- a/tests/models/mbart/test_tokenization_mbart.py
+++ b/tests/models/mbart/test_tokenization_mbart.py
@@ -213,7 +213,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
]
tgt_text = [
"Åeful ONU declarÄ cÄ nu existÄ o soluÅ£ie militarÄ Ć®n Siria",
- 'Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu vor face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£ele Åi mizeria pentru milioane de oameni.',
+ "Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei"
+ ' pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu vor'
+ " face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£ele Åi mizeria pentru milioane de oameni.",
]
expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]
@@ -263,33 +265,27 @@ def test_special_tokens_unaffacted_by_save_load(self):
@require_torch
def test_batch_fairseq_parity(self):
- batch = self.tokenizer(self.src_text, padding=True)
- with self.tokenizer.as_target_tokenizer():
- targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
- labels = targets["input_ids"]
- batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
+ batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
+ batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
- assert batch.input_ids[1][-2:] == [2, EN_CODE]
- assert batch.decoder_input_ids[1][0] == RO_CODE
+ assert batch.input_ids[1][-2:].tolist() == [2, EN_CODE]
+ assert batch.decoder_input_ids[1][0].tolist() == RO_CODE
assert batch.decoder_input_ids[1][-1] == 2
- assert labels[1][-2:].tolist() == [2, RO_CODE]
+ assert batch.labels[1][-2:].tolist() == [2, RO_CODE]
@require_torch
def test_enro_tokenizer_prepare_batch(self):
batch = self.tokenizer(
- self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
+ self.src_text,
+ text_target=self.tgt_text,
+ padding=True,
+ truncation=True,
+ max_length=len(self.expected_src_tokens),
+ return_tensors="pt",
)
- with self.tokenizer.as_target_tokenizer():
- targets = self.tokenizer(
- self.tgt_text,
- padding=True,
- truncation=True,
- max_length=len(self.expected_src_tokens),
- return_tensors="pt",
- )
- labels = targets["input_ids"]
- batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
+
+ batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
self.assertIsInstance(batch, BatchEncoding)
@@ -304,8 +300,9 @@ def test_enro_tokenizer_prepare_batch(self):
def test_seq2seq_max_length(self):
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
- with self.tokenizer.as_target_tokenizer():
- targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
+ targets = self.tokenizer(
+ text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
+ )
labels = targets["input_ids"]
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
diff --git a/tests/models/mbart50/test_tokenization_mbart50.py b/tests/models/mbart50/test_tokenization_mbart50.py
index 63adfe8436d5..d10d51df907c 100644
--- a/tests/models/mbart50/test_tokenization_mbart50.py
+++ b/tests/models/mbart50/test_tokenization_mbart50.py
@@ -203,7 +203,9 @@ class MBart50OneToManyIntegrationTest(unittest.TestCase):
]
tgt_text = [
"Åeful ONU declarÄ cÄ nu existÄ o soluÅ£ie militarÄ Ć®n Siria",
- 'Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu vor face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£ele Åi mizeria pentru milioane de oameni.',
+ "Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei"
+ ' pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu vor'
+ " face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£ele Åi mizeria pentru milioane de oameni.",
]
expected_src_tokens = [EN_CODE, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
@@ -254,35 +256,27 @@ def test_special_tokens_unaffacted_by_save_load(self):
@require_torch
def test_batch_fairseq_parity(self):
- batch = self.tokenizer(self.src_text, padding=True)
- with self.tokenizer.as_target_tokenizer():
- targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
- labels = targets["input_ids"]
- batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
- labels = labels.tolist()
+ batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
+ batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
assert batch.input_ids[1][0] == EN_CODE
assert batch.input_ids[1][-1] == 2
- assert labels[1][0] == RO_CODE
- assert labels[1][-1] == 2
- assert batch.decoder_input_ids[1][:2] == [2, RO_CODE]
+ assert batch.labels[1][0] == RO_CODE
+ assert batch.labels[1][-1] == 2
+ assert batch.decoder_input_ids[1][:2].tolist() == [2, RO_CODE]
@require_torch
def test_tokenizer_prepare_batch(self):
batch = self.tokenizer(
- self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
+ self.src_text,
+ text_target=self.tgt_text,
+ padding=True,
+ truncation=True,
+ max_length=len(self.expected_src_tokens),
+ return_tensors="pt",
)
- with self.tokenizer.as_target_tokenizer():
- targets = self.tokenizer(
- self.tgt_text,
- padding=True,
- truncation=True,
- max_length=len(self.expected_src_tokens),
- return_tensors="pt",
- )
- labels = targets["input_ids"]
- batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
+ batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
self.assertIsInstance(batch, BatchEncoding)
@@ -297,8 +291,9 @@ def test_tokenizer_prepare_batch(self):
def test_seq2seq_max_target_length(self):
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
- with self.tokenizer.as_target_tokenizer():
- targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
+ targets = self.tokenizer(
+ text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
+ )
labels = targets["input_ids"]
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
diff --git a/tests/models/mctct/__init__.py b/tests/models/mctct/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/mctct/test_feature_extraction_mctct.py b/tests/models/mctct/test_feature_extraction_mctct.py
new file mode 100644
index 000000000000..e0c77ad450fd
--- /dev/null
+++ b/tests/models/mctct/test_feature_extraction_mctct.py
@@ -0,0 +1,274 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import itertools
+import random
+import unittest
+
+import numpy as np
+
+from transformers import is_speech_available
+from transformers.testing_utils import require_torch, require_torchaudio
+
+from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
+
+
+if is_speech_available():
+ from transformers import MCTCTFeatureExtractor
+
+global_rng = random.Random()
+
+
+def floats_list(shape, scale=1.0, rng=None, name=None):
+ """Creates a random float32 tensor"""
+ if rng is None:
+ rng = global_rng
+
+ values = []
+ for _batch_idx in range(shape[0]):
+ values.append([])
+ for _ in range(shape[1]):
+ values[-1].append(rng.random() * scale)
+
+ return values
+
+
+@require_torch
+@require_torchaudio
+class MCTCTFeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ min_seq_length=400,
+ max_seq_length=2000,
+ feature_size=24,
+ num_mel_bins=24,
+ padding_value=0.0,
+ sampling_rate=16_000,
+ return_attention_mask=True,
+ do_normalize=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.min_seq_length = min_seq_length
+ self.max_seq_length = max_seq_length
+ self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
+ self.feature_size = feature_size
+ self.num_mel_bins = num_mel_bins
+ self.padding_value = padding_value
+ self.sampling_rate = sampling_rate
+ self.return_attention_mask = return_attention_mask
+ self.do_normalize = do_normalize
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "feature_size": self.feature_size,
+ "num_mel_bins": self.num_mel_bins,
+ "padding_value": self.padding_value,
+ "sampling_rate": self.sampling_rate,
+ "return_attention_mask": self.return_attention_mask,
+ "do_normalize": self.do_normalize,
+ }
+
+ def prepare_inputs_for_common(self, equal_length=False, numpify=False):
+ def _flatten(list_of_lists):
+ return list(itertools.chain(*list_of_lists))
+
+ if equal_length:
+ speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
+ else:
+ # make sure that inputs increase in size
+ speech_inputs = [
+ floats_list((x, self.feature_size))
+ for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
+ ]
+ if numpify:
+ speech_inputs = [np.asarray(x) for x in speech_inputs]
+ return speech_inputs
+
+
+@require_torch
+@require_torchaudio
+class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
+
+ feature_extraction_class = MCTCTFeatureExtractor if is_speech_available() else None
+
+ def setUp(self):
+ self.feat_extract_tester = MCTCTFeatureExtractionTester(self)
+
+ def _check_zero_mean_unit_variance(self, input_vector):
+ self.assertTrue(np.all(np.mean(input_vector) < 1e-3))
+ self.assertTrue(np.all(np.abs(np.var(input_vector) - 1) < 1e-3))
+
+ def test_call(self):
+ # Tests that all calls wrap to encode_plus and batch_encode_plus
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ # create three inputs of length 8000, 10000, and 12000
+ speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
+ np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
+
+ # Test feature size
+ input_features = feature_extractor(np_speech_inputs, padding=True, return_tensors="np").input_features
+ self.assertTrue(input_features.ndim == 3)
+ self.assertTrue(input_features.shape[-1] == feature_extractor.feature_size)
+
+ # Test not batched input
+ encoded_sequences_1 = feature_extractor(speech_inputs[0], return_tensors="np").input_features
+ encoded_sequences_2 = feature_extractor(np_speech_inputs[0], return_tensors="np").input_features
+ self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
+
+ # Test batched
+ encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
+ encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
+ for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+ self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
+ def test_cepstral_mean_and_variance_normalization(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
+
+ paddings = ["longest", "max_length", "do_not_pad"]
+ max_lengths = [None, 16, None]
+ for max_length, padding in zip(max_lengths, paddings):
+ inputs = feature_extractor(
+ speech_inputs,
+ padding=padding,
+ max_length=max_length,
+ return_attention_mask=True,
+ truncation=max_length is not None, # reference to #16419
+ )
+ input_features = inputs.input_features
+ attention_mask = inputs.attention_mask
+ fbank_feat_lengths = [np.sum(x) for x in attention_mask]
+ self._check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]])
+ self._check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]])
+ self._check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]])
+
+ def test_cepstral_mean_and_variance_normalization_np(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
+
+ paddings = ["longest", "max_length", "do_not_pad"]
+ max_lengths = [None, 16, None]
+ for max_length, padding in zip(max_lengths, paddings):
+ inputs = feature_extractor(
+ speech_inputs,
+ max_length=max_length,
+ padding=padding,
+ return_tensors="np",
+ return_attention_mask=True,
+ truncation=max_length is not None,
+ )
+ input_features = inputs.input_features
+ attention_mask = inputs.attention_mask
+ fbank_feat_lengths = [np.sum(x) for x in attention_mask]
+
+ self._check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]])
+ self.assertTrue(input_features[0][fbank_feat_lengths[0] :].sum() < 1e-6)
+ self._check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]])
+ self.assertTrue(input_features[0][fbank_feat_lengths[1] :].sum() < 1e-6)
+ self._check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]])
+
+ def test_cepstral_mean_and_variance_normalization_trunc_max_length(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
+ inputs = feature_extractor(
+ speech_inputs,
+ padding="max_length",
+ max_length=4,
+ truncation=True,
+ return_tensors="np",
+ return_attention_mask=True,
+ )
+ input_features = inputs.input_features
+ attention_mask = inputs.attention_mask
+ fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
+
+ self._check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
+ self._check_zero_mean_unit_variance(input_features[1])
+ self._check_zero_mean_unit_variance(input_features[2])
+
+ def test_cepstral_mean_and_variance_normalization_trunc_longest(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
+ inputs = feature_extractor(
+ speech_inputs,
+ padding="longest",
+ max_length=4,
+ truncation=True,
+ return_tensors="np",
+ return_attention_mask=True,
+ )
+ input_features = inputs.input_features
+ attention_mask = inputs.attention_mask
+ fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
+
+ self._check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
+ self._check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
+ self._check_zero_mean_unit_variance(input_features[2])
+
+ # make sure that if max_length < longest -> then pad to max_length
+ self.assertEqual(input_features.shape, (3, 4, 24))
+
+ speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
+ inputs = feature_extractor(
+ speech_inputs,
+ padding="longest",
+ max_length=16,
+ truncation=True,
+ return_tensors="np",
+ return_attention_mask=True,
+ )
+ input_features = inputs.input_features
+ attention_mask = inputs.attention_mask
+ fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
+
+ self._check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
+ self._check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
+ self._check_zero_mean_unit_variance(input_features[2])
+
+ # make sure that if max_length < longest -> then pad to max_length
+ self.assertEqual(input_features.shape, (3, 16, 24))
+
+ def test_double_precision_pad(self):
+ import torch
+
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
+ py_speech_inputs = np_speech_inputs.tolist()
+
+ for inputs in [py_speech_inputs, np_speech_inputs]:
+ np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
+ self.assertTrue(np_processed.input_features.dtype == np.float32)
+ pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
+ self.assertTrue(pt_processed.input_features.dtype == torch.float32)
+
+ def test_different_window(self):
+ import torch
+
+ init_dict = self.feat_extract_tester.prepare_feat_extract_dict()
+ init_dict["win_function"] = "hann_window"
+
+ feature_extractor = self.feature_extraction_class(**init_dict)
+ np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
+ py_speech_inputs = np_speech_inputs.tolist()
+
+ for inputs in [py_speech_inputs, np_speech_inputs]:
+ np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
+ self.assertTrue(np_processed.input_features.dtype == np.float32)
+ pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
+ self.assertTrue(pt_processed.input_features.dtype == torch.float32)
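For orientation, a minimal usage sketch of the feature extractor exercised by the new test file above; it mirrors the constructor arguments and call pattern used in the tests (random data stands in for real 16 kHz audio, and torchaudio must be installed, as the test decorators indicate):

    import numpy as np
    from transformers import MCTCTFeatureExtractor

    feature_extractor = MCTCTFeatureExtractor(
        feature_size=24, num_mel_bins=24, padding_value=0.0, sampling_rate=16_000,
        return_attention_mask=True, do_normalize=True,
    )
    # two fake mono waveforms of different lengths, nominally sampled at 16 kHz
    speech = [np.random.randn(8_000).astype(np.float32), np.random.randn(12_000).astype(np.float32)]

    inputs = feature_extractor(speech, padding=True, return_tensors="np", return_attention_mask=True)
    print(inputs.input_features.shape)  # (batch, frames, feature_size), last dim is 24 here
    print(inputs.attention_mask.shape)  # (batch, frames)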
diff --git a/tests/models/mctct/test_modeling_mctct.py b/tests/models/mctct/test_modeling_mctct.py
new file mode 100644
index 000000000000..ee4a9efc2fef
--- /dev/null
+++ b/tests/models/mctct/test_modeling_mctct.py
@@ -0,0 +1,647 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch MCTCT model. """
+
+import inspect
+import math
+import unittest
+
+from datasets import load_dataset
+
+from transformers import MCTCTConfig, is_torch_available
+from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import MCTCTForCTC, MCTCTModel, MCTCTProcessor
+
+
+class MCTCTModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=10,
+ seq_length=40, # speech is longer
+ is_training=False,
+ vocab_size=32,
+ hidden_size=128 * 4,
+ num_hidden_layers=4,
+ intermediate_size=20,
+ num_attention_heads=4,
+ attention_head_dim=128,
+ max_position_embeddings=920,
+ layer_norm_eps=1e-5,
+ layerdrop=0.3,
+ hidden_act="relu",
+ initializer_range=0.02,
+ hidden_dropout_prob=0.3,
+ attention_probs_dropout_prob=0.3,
+ conv_glu_dim=1,
+ conv_dropout=0.3,
+ num_conv_layers=1,
+ conv_kernel=(7,),
+ conv_stride=(3,),
+ input_feat_per_channel=80,
+ input_channels=1,
+ conv_channels=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length # speech is longer
+ self.is_training = is_training
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+
+ self.attention_head_dim = attention_head_dim
+ self.max_position_embeddings = max_position_embeddings
+
+ self.layer_norm_eps = layer_norm_eps
+ self.layerdrop = layerdrop
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+
+ self.conv_glu_dim = conv_glu_dim
+ self.conv_dropout = conv_dropout
+ self.num_conv_layers = num_conv_layers
+ self.conv_kernel = conv_kernel
+ self.conv_stride = conv_stride
+ self.input_feat_per_channel = input_feat_per_channel
+ self.input_channels = input_channels
+ self.conv_channels = conv_channels
+
+ output_seq_length = self.seq_length
+ dilation = 1
+ for _, kernel_sz, stride in zip(range(self.num_conv_layers), self.conv_kernel, self.conv_stride):
+ padding = kernel_sz // 2
+ output_seq_length = output_seq_length + 2 * padding - dilation * (kernel_sz - 1) - 1
+ output_seq_length = torch.div(output_seq_length, stride, rounding_mode="trunc") + 1
+
+ self.output_seq_length = int(math.ceil(output_seq_length))
+ self.encoder_seq_length = self.output_seq_length
+
+ def prepare_config_and_inputs(self):
+ input_features = floats_tensor(
+ [self.batch_size, self.seq_length, self.input_feat_per_channel], self.vocab_size
+ )
+ attention_mask = torch.ones([self.batch_size, self.seq_length], dtype=torch.long, device=torch_device)
+
+ config = self.get_config()
+
+ return config, input_features, attention_mask
+
+ def get_config(self):
+ return MCTCTConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ intermediate_size=self.intermediate_size,
+ num_attention_heads=self.num_attention_heads,
+ attention_head_dim=self.attention_head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ layer_norm_eps=self.layer_norm_eps,
+ layerdrop=self.layerdrop,
+ hidden_act=self.hidden_act,
+ initializer_range=self.initializer_range,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ conv_glu_dim=self.conv_glu_dim,
+ conv_dropout=self.conv_dropout,
+ num_conv_layers=self.num_conv_layers,
+ conv_kernel=self.conv_kernel,
+ conv_stride=self.conv_stride,
+ input_feat_per_channel=self.input_feat_per_channel,
+ input_channels=self.input_channels,
+ conv_channels=self.conv_channels,
+ )
+
+ def create_and_check_model(self, config, input_features, attention_mask):
+ model = MCTCTModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_features, attention_mask=attention_mask)
+
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
+ )
+
+ def create_and_check_model_for_ctc(self, config, input_features, attention_mask):
+ config.add_adapter = True
+ config.output_hidden_size = 2 * config.hidden_size
+ model = MCTCTForCTC(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_features, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.adapter_output_seq_length, self.vocab_size)
+ )
+
+ def create_and_check_batch_inference(self, config, input_features, *args):
+ # test does not pass for models making use of `group_norm`
+ # check: https://github.com/pytorch/fairseq/issues/3227
+ model = MCTCTModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ input_features = input_features[:3]
+ attention_mask = torch.ones(input_features.shape[:-1], device=torch_device, dtype=torch.bool)
+
+ input_lengths = [input_features.shape[-1] // i for i in [2, 2, 1]]
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_features[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0.0
+
+ batch_outputs = model(input_features, attention_mask=attention_mask).last_hidden_state
+
+ for i in range(input_features.shape[0]):
+ input_slice = input_features[i : i + 1, : input_lengths[i]]
+ output = model(input_slice).last_hidden_state
+
+ batch_output = batch_outputs[i : i + 1, : output.shape[1]]
+ self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
+
+ def check_ctc_loss(self, config, input_features, *args):
+ model = MCTCTForCTC(config=config)
+ model.to(torch_device)
+
+ # make sure that dropout is disabled
+ model.eval()
+
+ input_features = input_features[:3]
+
+ # input_features is a 2D window for each sequence
+ attention_mask = torch.ones(input_features.shape[:-1], device=torch_device, dtype=torch.long)
+
+ # -2 since input_features is a 2D window for each sequence in batch
+ input_lengths = [input_features.shape[-2] // i for i in [2, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_features.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
+ # pad input
+ for i in range(len(input_lengths)):
+ input_features[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0
+
+ model.config.ctc_loss_reduction = "sum"
+ sum_loss = model(input_features, attention_mask=attention_mask, labels=labels).loss.item()
+
+ model.config.ctc_loss_reduction = "mean"
+ mean_loss = model(input_features, attention_mask=attention_mask, labels=labels).loss.item()
+
+ self.parent.assertTrue(isinstance(sum_loss, float))
+ self.parent.assertTrue(isinstance(mean_loss, float))
+
+ def check_ctc_training(self, config, input_features, *args):
+ config.ctc_zero_infinity = True
+ model = MCTCTForCTC(config=config)
+ model.to(torch_device)
+ model.train()
+
+ input_features = input_features[:3]
+
+ input_lengths = [input_features.shape[-2] // i for i in [2, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_features.shape[0], max(max_length_labels) - 1), model.config.vocab_size)
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_features[i, input_lengths[i] :] = 0.0
+
+ if max_length_labels[i] < labels.shape[-1]:
+ # it's important that we make sure that target lengths are at least
+ # one shorter than logit lengths to prevent -inf
+ labels[i, max_length_labels[i] - 1 :] = -100
+
+ loss = model(input_features, labels=labels).loss
+ self.parent.assertFalse(torch.isinf(loss).item())
+
+ loss.backward()
+
+ def check_labels_out_of_vocab(self, config, input_features, *args):
+ model = MCTCTForCTC(config)
+ model.to(torch_device)
+ model.train()
+
+ input_features = input_features[:3]
+
+ input_lengths = [input_features.shape[-1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_features.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100)
+
+ with self.parent.assertRaises(ValueError):
+ model(input_features, labels=labels)
+
+ def prepare_config_and_inputs_for_common(self):
+ config, input_features, attention_mask = self.prepare_config_and_inputs()
+ inputs_dict = {"input_features": input_features, "attention_mask": attention_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class MCTCTModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (MCTCTForCTC, MCTCTModel) if is_torch_available() else ()
+ test_pruning = False
+ test_headmasking = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = MCTCTModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=MCTCTConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_ctc_loss_inference(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_loss(*config_and_inputs)
+
+ def test_ctc_train(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_training(*config_and_inputs)
+
+ def test_labels_out_of_vocab(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
+
+ # MCTCT has no inputs_embeds
+ def test_inputs_embeds(self):
+ pass
+
+ # `input_ids` is renamed to `input_features`
+ def test_forward_signature(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = [
+ "input_features",
+ "attention_mask",
+ "head_mask",
+ "output_attentions",
+ "output_hidden_states",
+ "return_dict",
+ ]
+ self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
+
+ # MCTCT cannot resize token embeddings
+ # since it has no tokens embeddings
+ def test_resize_tokens_embeddings(self):
+ pass
+
+ # MCTCT has no inputs_embeds
+ def test_model_common_attributes(self):
+ pass
+
+ def test_retain_grad_hidden_states_attentions(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = True
+ config.layerdrop = 0.0
+
+ # no need to test all models as different heads yield the same functionality
+ model_class = self.all_model_classes[0]
+ model = model_class(config)
+ model.to(torch_device)
+
+ input_features = inputs_dict["input_features"]
+
+ input_lengths = torch.tensor(
+ [input_features.shape[1] for _ in range(input_features.shape[0])], dtype=torch.long, device=torch_device
+ )
+ output_lengths = model._get_feat_extract_output_lengths(input_lengths)
+
+ labels = ids_tensor((input_features.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
+ inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
+ inputs_dict["labels"] = labels
+
+ outputs = model(**inputs_dict)
+
+ output = outputs[0]
+
+ # Encoder-/Decoder-only models
+ hidden_states = outputs.hidden_states[0]
+ attentions = outputs.attentions[0]
+
+ hidden_states.retain_grad()
+ attentions.retain_grad()
+
+ output.flatten()[0].backward(retain_graph=True)
+
+ self.assertIsNotNone(hidden_states.grad)
+ self.assertIsNotNone(attentions.grad)
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ uniform_init_parms = [
+ "conv.weight",
+ "masked_spec_embed",
+ "codevectors",
+ "quantizer.weight_proj.weight",
+ "project_hid.weight",
+ "project_hid.bias",
+ "project_q.weight",
+ "project_q.bias",
+ "feature_projection.projection.weight",
+ "feature_projection.projection.bias",
+ "objective.weight",
+ ]
+ if param.requires_grad:
+ if any([x in name for x in uniform_init_parms]):
+ self.assertTrue(
+ -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ # overwrite from test_modeling_common
+ def _mock_init_weights(self, module):
+ if hasattr(module, "weight") and module.weight is not None:
+ module.weight.data.fill_(3)
+ if hasattr(module, "weight_g") and module.weight_g is not None:
+ module.weight_g.data.fill_(3)
+ if hasattr(module, "weight_v") and module.weight_v is not None:
+ module.weight_v.data.fill_(3)
+ if hasattr(module, "bias") and module.bias is not None:
+ module.bias.data.fill_(3)
+ if hasattr(module, "codevectors") and module.codevectors is not None:
+ module.codevectors.data.fill_(3)
+ if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
+ module.masked_spec_embed.data.fill_(3)
+
+ @slow
+ def test_model_from_pretrained(self):
+ model = MCTCTModel.from_pretrained("speechbrain/m-ctc-t-large")
+ self.assertIsNotNone(model)
+
+
+@require_torch
+class MCTCTRobustModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (MCTCTForCTC, MCTCTModel) if is_torch_available() else ()
+ test_pruning = False
+ test_headmasking = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = MCTCTModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=MCTCTConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_batched_inference(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_batch_inference(*config_and_inputs)
+
+ def test_ctc_loss_inference(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_loss(*config_and_inputs)
+
+ def test_ctc_train(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_training(*config_and_inputs)
+
+ def test_labels_out_of_vocab(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
+
+ # MCTCT has no inputs_embeds
+ def test_inputs_embeds(self):
+ pass
+
+ # `input_ids` is renamed to `input_features`
+ def test_forward_signature(self):
+ pass
+
+ # MCTCT cannot resize token embeddings
+ # since it has no tokens embeddings
+ def test_resize_tokens_embeddings(self):
+ pass
+
+ # MCTCT has no inputs_embeds
+ # and thus the `get_input_embeddings` fn
+ # is not implemented
+ def test_model_common_attributes(self):
+ pass
+
+ def test_retain_grad_hidden_states_attentions(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = True
+
+ # no need to test all models as different heads yield the same functionality
+ model_class = self.all_model_classes[0]
+ model = model_class(config)
+ model.to(torch_device)
+
+ # set layer drop to 0
+ model.config.layerdrop = 0.0
+
+ input_features = inputs_dict["input_features"]
+
+ input_lengths = torch.tensor(
+ [input_features.shape[1] for _ in range(input_features.shape[0])], dtype=torch.long, device=torch_device
+ )
+ output_lengths = model._get_feat_extract_output_lengths(input_lengths)
+
+ labels = ids_tensor((input_features.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
+ inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
+ inputs_dict["labels"] = labels
+
+ outputs = model(**inputs_dict)
+
+ output = outputs[0]
+
+ # Encoder-/Decoder-only models
+ hidden_states = outputs.hidden_states[0]
+ attentions = outputs.attentions[0]
+
+ hidden_states.retain_grad()
+ attentions.retain_grad()
+
+ output.flatten()[0].backward(retain_graph=True)
+
+ self.assertIsNotNone(hidden_states.grad)
+ self.assertIsNotNone(attentions.grad)
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ uniform_init_parms = [
+ "conv.weight",
+ "masked_spec_embed",
+ "codevectors",
+ "quantizer.weight_proj.weight",
+ "project_hid.weight",
+ "project_hid.bias",
+ "project_q.weight",
+ "project_q.bias",
+ "feature_projection.projection.weight",
+ "feature_projection.projection.bias",
+ "objective.weight",
+ ]
+ if param.requires_grad:
+ if any([x in name for x in uniform_init_parms]):
+ self.assertTrue(
+ -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ # overwrite from test_modeling_common
+ def _mock_init_weights(self, module):
+ if hasattr(module, "weight") and module.weight is not None:
+ module.weight.data.fill_(3)
+ if hasattr(module, "weight_g") and module.weight_g is not None:
+ module.weight_g.data.fill_(3)
+ if hasattr(module, "weight_v") and module.weight_v is not None:
+ module.weight_v.data.fill_(3)
+ if hasattr(module, "bias") and module.bias is not None:
+ module.bias.data.fill_(3)
+ if hasattr(module, "codevectors") and module.codevectors is not None:
+ module.codevectors.data.fill_(3)
+ if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
+ module.masked_spec_embed.data.fill_(3)
+
+ @unittest.skip(reason="Feed forward chunking is not implemented")
+ def test_feed_forward_chunking(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ model = MCTCTModel.from_pretrained("speechbrain/m-ctc-t-large")
+ self.assertIsNotNone(model)
+
+
+@require_torch
+@require_soundfile
+@slow
+class MCTCTModelIntegrationTest(unittest.TestCase):
+ def _load_datasamples(self, num_samples):
+ ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ # automatic decoding with librispeech
+ speech_samples = ds.sort("id").filter(
+ lambda x: x["id"] in [f"1272-141231-000{i}" for i in range(num_samples)]
+ )[:num_samples]["audio"]
+
+ return [x["array"] for x in speech_samples]
+
+ def test_inference_ctc_normal(self):
+ model = MCTCTForCTC.from_pretrained("speechbrain/m-ctc-t-large")
+ model.to(torch_device)
+ processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large", do_lower_case=True)
+ input_speech = self._load_datasamples(1)
+
+ input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
+
+ with torch.no_grad():
+ logits = model(input_features).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_trans = processor.batch_decode(predicted_ids)
+
+ EXPECTED_TRANSCRIPTIONS = ["a man said to the universe, sir, i exist."]
+ self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+
+ def test_inference_ctc_normal_batched(self):
+ model = MCTCTForCTC.from_pretrained("speechbrain/m-ctc-t-large")
+ model.to(torch_device)
+ processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large", do_lower_case=True)
+
+ input_speech = self._load_datasamples(2)
+
+ inputs = processor(input_speech, return_tensors="pt", padding=True)
+
+ input_features = inputs.input_features.to(torch_device)
+ attention_mask = inputs.attention_mask.to(torch_device)
+
+ with torch.no_grad():
+ logits = model(input_features, attention_mask=attention_mask).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_trans = processor.batch_decode(predicted_ids)
+
+ EXPECTED_TRANSCRIPTIONS = [
+ "a man said to the universe, sir, i exist.",
+ '"sweat-covered brion\'s body, trickling into the tight-lowing clossa was the only germent huor."',
+ ]
+ self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+
+ def test_inference_ctc_robust_batched(self):
+ model = MCTCTForCTC.from_pretrained("speechbrain/m-ctc-t-large").to(torch_device)
+ processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large", do_lower_case=True)
+
+ input_speech = self._load_datasamples(4)
+
+ inputs = processor(input_speech, return_tensors="pt", padding=True, return_attention_mask=True)
+
+ input_features = inputs.input_features.to(torch_device)
+ attention_mask = inputs.attention_mask.to(torch_device)
+
+ with torch.no_grad():
+ logits = model(input_features, attention_mask=attention_mask).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_trans = processor.batch_decode(predicted_ids)
+
+ EXPECTED_TRANSCRIPTIONS = [
+ "a man said to the universe, sir, i exist.",
+ '"sweat-covered brion\'s body, trickling into the tight-lowing clossa was the only germent huor." "',
+ "\"the cadona's chest still-dripping bloodthe acofis overstrained eyes, even the soring arena around him"
+ " with thousands of spectators retrivialities not worth-thinking about.",
+ "his instant panic was followed by a small sharp blow high on his chestr.",
+ ]
+ self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
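As a quick check of the sequence-length bookkeeping in MCTCTModelTester above, with its defaults (seq_length=40, one conv layer, kernel 7, stride 3): padding is 7 // 2 = 3, so output_seq_length = (40 + 2*3 - (7 - 1) - 1) // 3 + 1 = 39 // 3 + 1 = 14, which is the length the model tests compare the hidden states against.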
diff --git a/tests/models/mctct/test_processor_mctct.py b/tests/models/mctct/test_processor_mctct.py
new file mode 100644
index 000000000000..821e44b48e24
--- /dev/null
+++ b/tests/models/mctct/test_processor_mctct.py
@@ -0,0 +1,146 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+
+from transformers import MCTCTProcessor, is_speech_available, is_torch_available
+from transformers.file_utils import FEATURE_EXTRACTOR_NAME
+from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2CTCTokenizer
+from transformers.testing_utils import require_torch, require_torchaudio
+
+
+if is_speech_available() and is_torch_available():
+ from transformers import MCTCTFeatureExtractor
+
+ from .test_feature_extraction_mctct import floats_list
+
+
+@require_torch
+@require_torchaudio
+class MCTCTProcessorTest(unittest.TestCase):
+ def setUp(self):
+ vocab = " | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
+ vocab_tokens = dict(zip(vocab, range(len(vocab))))
+
+ self.add_kwargs_tokens_map = {
+ "pad_token": "",
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ }
+ feature_extractor_map = {
+ "feature_size": 1,
+ "padding_value": 0.0,
+ "sampling_rate": 16000,
+ "return_attention_mask": False,
+ "do_normalize": True,
+ }
+
+ self.tmpdirname = tempfile.mkdtemp()
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+ self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
+ with open(self.vocab_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(vocab_tokens) + "\n")
+
+ with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(feature_extractor_map) + "\n")
+
+ def get_tokenizer(self, **kwargs_init):
+ kwargs = self.add_kwargs_tokens_map.copy()
+ kwargs.update(kwargs_init)
+ return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_feature_extractor(self, **kwargs):
+ return MCTCTFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdirname)
+
+ def test_save_load_pretrained_default(self):
+ tokenizer = self.get_tokenizer()
+ feature_extractor = self.get_feature_extractor()
+
+ processor = MCTCTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ processor.save_pretrained(self.tmpdirname)
+ processor = MCTCTProcessor.from_pretrained(self.tmpdirname)
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
+ self.assertIsInstance(processor.tokenizer, Wav2Vec2CTCTokenizer)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, MCTCTFeatureExtractor)
+
+ def test_save_load_pretrained_additional_features(self):
+ processor = MCTCTProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
+ processor.save_pretrained(self.tmpdirname)
+
+ tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
+ feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
+
+ processor = MCTCTProcessor.from_pretrained(
+ self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
+ )
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+ self.assertIsInstance(processor.tokenizer, Wav2Vec2CTCTokenizer)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, MCTCTFeatureExtractor)
+
+ def test_feature_extractor(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = MCTCTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ raw_speech = floats_list((3, 1000))
+
+ input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
+ input_processor = processor(raw_speech, return_tensors="np")
+
+ for key in input_feat_extract.keys():
+ self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+ def test_tokenizer(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = MCTCTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ input_str = "This is a test string"
+
+ encoded_processor = processor(text=input_str)
+
+ encoded_tok = tokenizer(input_str)
+
+ for key in encoded_tok.keys():
+ self.assertListEqual(encoded_tok[key], encoded_processor[key])
+
+ def test_tokenizer_decode(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = MCTCTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
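+ # batch_decode on the processor is expected to simply delegate to the tokenizer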
+ predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
+
+ decoded_processor = processor.batch_decode(predicted_ids)
+ decoded_tok = tokenizer.batch_decode(predicted_ids)
+
+ self.assertListEqual(decoded_tok, decoded_processor)
diff --git a/tests/models/mluke/test_tokenization_mluke.py b/tests/models/mluke/test_tokenization_mluke.py
index 66d669924652..681825c7dccf 100644
--- a/tests/models/mluke/test_tokenization_mluke.py
+++ b/tests/models/mluke/test_tokenization_mluke.py
@@ -365,7 +365,8 @@ def test_text_pair_no_padding_or_truncation(self):
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
- " ISO 639-3 uses the code fas for the dialects spoken across Iran and ć¢ćć¬ćć¹ćæć³ ( Afghanistan ).",
+ " ISO 639-3 uses the code fas for the dialects spoken across Iran and ć¢ćć¬ćć¹ćæć³ ( Afghanistan"
+ " ).",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3"
@@ -423,7 +424,8 @@ def test_text_pair_only_entity_spans_no_padding_or_truncation(self):
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
- " ISO 639-3 uses the code fas for the dialects spoken across Iran and ć¢ćć¬ćć¹ćæć³ ( Afghanistan ).",
+ " ISO 639-3 uses the code fas for the dialects spoken across Iran and ć¢ćć¬ćć¹ćæć³ ( Afghanistan"
+ " ).",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3"
@@ -506,7 +508,8 @@ def test_entity_classification_no_padding_or_truncation(self):
self.assertEqual(len(encoding["token_type_ids"]), 23)
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
- " Japanese is anEast Asian languagespoken by about 128 million people, primarily in Japan.",
+ " Japanese is anEast Asian languagespoken by about 128 million people, primarily in"
+ " Japan.",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][4:9], spaces_between_special_tokens=False),
@@ -559,7 +562,8 @@ def test_entity_pair_classification_no_padding_or_truncation(self):
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
- "Japaneseis an East Asian language spoken by about 128 million people, primarily inJapan.",
+ "Japaneseis an East Asian language spoken by about 128 million people, primarily"
+ " inJapan.",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][1:4], spaces_between_special_tokens=False),
diff --git a/tests/models/mobilebert/test_modeling_tf_mobilebert.py b/tests/models/mobilebert/test_modeling_tf_mobilebert.py
index 9db55cec2d58..1800cd3ca143 100644
--- a/tests/models/mobilebert/test_modeling_tf_mobilebert.py
+++ b/tests/models/mobilebert/test_modeling_tf_mobilebert.py
@@ -17,7 +17,7 @@
import unittest
from transformers import MobileBertConfig, is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers.testing_utils import require_tf, slow, tooslow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
@@ -306,8 +306,8 @@ def test_model_common_attributes(self):
name = model.get_bias()
assert name is None
+ @tooslow
def test_saved_model_creation(self):
- # This test is too long (>30sec) and makes fail the CI
pass
@slow
diff --git a/tests/models/mobilebert/test_tokenization_mobilebert.py b/tests/models/mobilebert/test_tokenization_mobilebert.py
new file mode 100644
index 000000000000..395f4a2aab2c
--- /dev/null
+++ b/tests/models/mobilebert/test_tokenization_mobilebert.py
@@ -0,0 +1,345 @@
+# coding=utf-8
+# Copyright 2022 Leon Derczynski. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the MobileBERT tokenizer. """
+
+
+import os
+import unittest
+
+from transformers import MobileBertTokenizer, MobileBertTokenizerFast
+from transformers.models.bert.tokenization_bert import (
+ VOCAB_FILES_NAMES,
+ BasicTokenizer,
+ WordpieceTokenizer,
+ _is_control,
+ _is_punctuation,
+ _is_whitespace,
+)
+from transformers.testing_utils import require_tokenizers, slow
+
+from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
+
+
+# Copied from transformers.tests.models.bert.test_tokenization_bert.py with Bert->MobileBert and pathfix
+@require_tokenizers
+class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+
+ tokenizer_class = MobileBertTokenizer
+ rust_tokenizer_class = MobileBertTokenizerFast
+ test_rust_tokenizer = True
+ space_between_special_tokens = True
+ from_pretrained_filter = filter_non_english
+ pre_trained_model_path = "google/mobilebert-uncased"
+
+ def setUp(self):
+ super().setUp()
+
+ vocab_tokens = [
+ "[UNK]",
+ "[CLS]",
+ "[SEP]",
+ "[PAD]",
+ "[MASK]",
+ "want",
+ "##want",
+ "##ed",
+ "wa",
+ "un",
+ "runn",
+ "##ing",
+ ",",
+ "low",
+ "lowest",
+ ]
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+ with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
+ vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+
+ self.tokenizers_list = [
+ (tokenizer_def[0], self.pre_trained_model_path, tokenizer_def[2]) # else the 'google/' prefix is stripped
+ for tokenizer_def in self.tokenizers_list
+ ]
+
+ def get_input_output_texts(self, tokenizer):
+ input_text = "UNwant\u00E9d,running"
+ output_text = "unwanted, running"
+ return input_text, output_text
+
+ def test_full_tokenizer(self):
+ tokenizer = self.tokenizer_class(self.vocab_file)
+
+ tokens = tokenizer.tokenize("UNwant\u00E9d,running")
+ self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
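+ # the expected ids below are simply the positions of these sub-tokens in the vocab_tokens list built in setUp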
+ self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11])
+
+ def test_rust_and_python_full_tokenizers(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ sequence = "UNwant\u00E9d,running"
+
+ tokens = tokenizer.tokenize(sequence)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ rust_tokenizer = self.get_rust_tokenizer()
+ ids = tokenizer.encode(sequence)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
+
+ # With lower casing
+ tokenizer = self.get_tokenizer(do_lower_case=True)
+ rust_tokenizer = self.get_rust_tokenizer(do_lower_case=True)
+
+ sequence = "UNwant\u00E9d,running"
+
+ tokens = tokenizer.tokenize(sequence)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ rust_tokenizer = self.get_rust_tokenizer()
+ ids = tokenizer.encode(sequence)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
+
+ def test_chinese(self):
+ tokenizer = BasicTokenizer()
+
+ self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"])
+
+ def test_basic_tokenizer_lower(self):
+ tokenizer = BasicTokenizer(do_lower_case=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["hello", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+ def test_basic_tokenizer_lower_strip_accents_false(self):
+ tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=False)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["hƤllo", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"])
+
+ def test_basic_tokenizer_lower_strip_accents_true(self):
+ tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+ def test_basic_tokenizer_lower_strip_accents_default(self):
+ tokenizer = BasicTokenizer(do_lower_case=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+ def test_basic_tokenizer_no_lower(self):
+ tokenizer = BasicTokenizer(do_lower_case=False)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
+ )
+
+ def test_basic_tokenizer_no_lower_strip_accents_false(self):
+ tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=False)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["HƤLLo", "!", "how", "Are", "yoU", "?"]
+ )
+
+ def test_basic_tokenizer_no_lower_strip_accents_true(self):
+ tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"]
+ )
+
+ def test_basic_tokenizer_respects_never_split_tokens(self):
+ tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"])
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
+ )
+
+ def test_wordpiece_tokenizer(self):
+ vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
+
+ vocab = {}
+ for i, token in enumerate(vocab_tokens):
+ vocab[token] = i
+ tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
+
+ self.assertListEqual(tokenizer.tokenize(""), [])
+
+ self.assertListEqual(tokenizer.tokenize("unwanted running"), ["un", "##want", "##ed", "runn", "##ing"])
+
+ self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
+
+ def test_is_whitespace(self):
+ self.assertTrue(_is_whitespace(" "))
+ self.assertTrue(_is_whitespace("\t"))
+ self.assertTrue(_is_whitespace("\r"))
+ self.assertTrue(_is_whitespace("\n"))
+ self.assertTrue(_is_whitespace("\u00A0"))
+
+ self.assertFalse(_is_whitespace("A"))
+ self.assertFalse(_is_whitespace("-"))
+
+ def test_is_control(self):
+ self.assertTrue(_is_control("\u0005"))
+
+ self.assertFalse(_is_control("A"))
+ self.assertFalse(_is_control(" "))
+ self.assertFalse(_is_control("\t"))
+ self.assertFalse(_is_control("\r"))
+
+ def test_is_punctuation(self):
+ self.assertTrue(_is_punctuation("-"))
+ self.assertTrue(_is_punctuation("$"))
+ self.assertTrue(_is_punctuation("`"))
+ self.assertTrue(_is_punctuation("."))
+
+ self.assertFalse(_is_punctuation("A"))
+ self.assertFalse(_is_punctuation(" "))
+
+ def test_clean_text(self):
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ # Example taken from the issue https://github.com/huggingface/tokenizers/issues/340
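+ # "\xad" (soft hyphen) is a Unicode format character, so cleaning removes it entirely and it yields no tokens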
+ self.assertListEqual([tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]])
+
+ self.assertListEqual(
+ [rust_tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]]
+ )
+
+ @slow
+ def test_sequence_builders(self):
+ tokenizer = self.tokenizer_class.from_pretrained("google/mobilebert-uncased")
+
+ text = tokenizer.encode("sequence builders", add_special_tokens=False)
+ text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
+
+ encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
+ encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
+
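+ # 101 and 102 are the [CLS] and [SEP] ids in the uncased WordPiece vocabulary used by google/mobilebert-uncased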
+ assert encoded_sentence == [101] + text + [102]
+ assert encoded_pair == [101] + text + [102] + text_2 + [102]
+
+ def test_offsets_with_special_characters(self):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ sentence = f"A, naĆÆve {tokenizer_r.mask_token} AllenNLP sentence."
+ tokens = tokenizer_r.encode_plus(
+ sentence,
+ return_attention_mask=False,
+ return_token_type_ids=False,
+ return_offsets_mapping=True,
+ add_special_tokens=True,
+ )
+
+ do_lower_case = tokenizer_r.do_lower_case if hasattr(tokenizer_r, "do_lower_case") else False
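+ # the uncased checkpoint lowercases and strips accents, so "naïve" collapses to a single "naive" piece in the second branch below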
+ expected_results = (
+ [
+ ((0, 0), tokenizer_r.cls_token),
+ ((0, 1), "A"),
+ ((1, 2), ","),
+ ((3, 5), "na"),
+ ((5, 6), "##ĆÆ"),
+ ((6, 8), "##ve"),
+ ((9, 15), tokenizer_r.mask_token),
+ ((16, 21), "Allen"),
+ ((21, 23), "##NL"),
+ ((23, 24), "##P"),
+ ((25, 33), "sentence"),
+ ((33, 34), "."),
+ ((0, 0), tokenizer_r.sep_token),
+ ]
+ if not do_lower_case
+ else [
+ ((0, 0), tokenizer_r.cls_token),
+ ((0, 1), "a"),
+ ((1, 2), ","),
+ ((3, 8), "naive"),
+ ((9, 15), tokenizer_r.mask_token),
+ ((16, 21), "allen"),
+ ((21, 23), "##nl"),
+ ((23, 24), "##p"),
+ ((25, 33), "sentence"),
+ ((33, 34), "."),
+ ((0, 0), tokenizer_r.sep_token),
+ ]
+ )
+
+ self.assertEqual(
+ [e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"])
+ )
+ self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
+
+ def test_change_tokenize_chinese_chars(self):
+ list_of_common_chinese_char = ["的", "人", "有"]
+ text_with_chinese_char = "".join(list_of_common_chinese_char)
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+
+ kwargs["tokenize_chinese_chars"] = True
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
+ ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
+
+ tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
+ tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)
+
+ # it is expected that each Chinese character is not preceded by "##"
+ self.assertListEqual(tokens_without_spe_char_p, list_of_common_chinese_char)
+ self.assertListEqual(tokens_without_spe_char_r, list_of_common_chinese_char)
+
+ kwargs["tokenize_chinese_chars"] = False
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
+ ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
+
+ tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
+ tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)
+
+ # it is expected that only the first Chinese character is not preceded by "##".
+ expected_tokens = [
+ f"##{token}" if idx != 0 else token for idx, token in enumerate(list_of_commun_chinese_char)
+ ]
+ self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
+ self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
diff --git a/tests/models/mobilevit/__init__.py b/tests/models/mobilevit/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/mobilevit/test_feature_extraction_mobilevit.py b/tests/models/mobilevit/test_feature_extraction_mobilevit.py
new file mode 100644
index 000000000000..f13267c541c9
--- /dev/null
+++ b/tests/models/mobilevit/test_feature_extraction_mobilevit.py
@@ -0,0 +1,191 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import MobileViTFeatureExtractor
+
+
+class MobileViTFeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ image_size=18,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=20,
+ do_center_crop=True,
+ crop_size=18,
+ do_flip_channel_order=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_flip_channel_order = do_flip_channel_order
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "do_resize": self.do_resize,
+ "size": self.size,
+ "do_center_crop": self.do_center_crop,
+ "crop_size": self.crop_size,
+ "do_flip_channel_order": self.do_flip_channel_order,
+ }
+
+
+@require_torch
+@require_vision
+class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = MobileViTFeatureExtractor if is_vision_available() else None
+
+ def setUp(self):
+ self.feature_extract_tester = MobileViTFeatureExtractionTester(self)
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "size"))
+ self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "do_flip_channel_order"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
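+ # with do_resize and do_center_crop enabled, every image comes out as crop_size x crop_size,
+ # which is what the shape assertions below rely on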
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
+
+ def test_call_numpy(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random numpy tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, np.ndarray)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
+
+ def test_call_pytorch(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PyTorch tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, torch.Tensor)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
diff --git a/tests/models/mobilevit/test_modeling_mobilevit.py b/tests/models/mobilevit/test_modeling_mobilevit.py
new file mode 100644
index 000000000000..84ffc7b89bc5
--- /dev/null
+++ b/tests/models/mobilevit/test_modeling_mobilevit.py
@@ -0,0 +1,342 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch MobileViT model. """
+
+
+import inspect
+import unittest
+
+from transformers import MobileViTConfig
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.utils import cached_property, is_torch_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import MobileViTForImageClassification, MobileViTForSemanticSegmentation, MobileViTModel
+ from transformers.models.mobilevit.modeling_mobilevit import MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import MobileViTFeatureExtractor
+
+
+class MobileViTConfigTester(ConfigTester):
+ def create_and_test_config_common_properties(self):
+ config = self.config_class(**self.inputs_dict)
+ self.parent.assertTrue(hasattr(config, "hidden_sizes"))
+ self.parent.assertTrue(hasattr(config, "neck_hidden_sizes"))
+ self.parent.assertTrue(hasattr(config, "num_attention_heads"))
+
+
+class MobileViTModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=32,
+ patch_size=2,
+ num_channels=3,
+ last_hidden_size=640,
+ num_attention_heads=4,
+ hidden_act="silu",
+ conv_kernel_size=3,
+ output_stride=32,
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ classifier_dropout_prob=0.1,
+ initializer_range=0.02,
+ is_training=True,
+ use_labels=True,
+ num_labels=10,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.last_hidden_size = last_hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.conv_kernel_size = conv_kernel_size
+ self.output_stride = output_stride
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.classifier_dropout_prob = classifier_dropout_prob
+ self.use_labels = use_labels
+ self.is_training = is_training
+ self.num_labels = num_labels
+ self.initializer_range = initializer_range
+ self.scope = scope
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ pixel_labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.num_labels)
+ pixel_labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels, pixel_labels
+
+ def get_config(self):
+ return MobileViTConfig(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ num_attention_heads=self.num_attention_heads,
+ hidden_act=self.hidden_act,
+ conv_kernel_size=self.conv_kernel_size,
+ output_stride=self.output_stride,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ classifier_dropout_prob=self.classifier_dropout_prob,
+ initializer_range=self.initializer_range,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
+ model = MobileViTModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
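+ # the backbone downsamples by output_stride, so the final feature map is (image_size // output_stride) on each side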
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (
+ self.batch_size,
+ self.last_hidden_size,
+ self.image_size // self.output_stride,
+ self.image_size // self.output_stride,
+ ),
+ )
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
+ config.num_labels = self.num_labels
+ model = MobileViTForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def create_and_check_for_semantic_segmentation(self, config, pixel_values, labels, pixel_labels):
+ config.num_labels = self.num_labels
+ model = MobileViTForSemanticSegmentation(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.logits.shape,
+ (
+ self.batch_size,
+ self.num_labels,
+ self.image_size // self.output_stride,
+ self.image_size // self.output_stride,
+ ),
+ )
+ result = model(pixel_values, labels=pixel_labels)
+ self.parent.assertEqual(
+ result.logits.shape,
+ (
+ self.batch_size,
+ self.num_labels,
+ self.image_size // self.output_stride,
+ self.image_size // self.output_stride,
+ ),
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels, pixel_labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class MobileViTModelTest(ModelTesterMixin, unittest.TestCase):
+ """
+ Here we also override some tests from test_modeling_common.py, as MobileViT does not use input_ids, inputs_embeds,
+ attention_mask or seq_length.
+ """
+
+ all_model_classes = (
+ (MobileViTModel, MobileViTForImageClassification, MobileViTForSemanticSegmentation)
+ if is_torch_available()
+ else ()
+ )
+
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ has_attentions = False
+
+ def setUp(self):
+ self.model_tester = MobileViTModelTester(self)
+ self.config_tester = MobileViTConfigTester(self, config_class=MobileViTConfig, has_text_modality=False)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(reason="MobileViT does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="MobileViT does not support input and output embeddings")
+ def test_model_common_attributes(self):
+ pass
+
+ @unittest.skip(reason="MobileViT does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.hidden_states
+
+ expected_num_stages = 5
+ self.assertEqual(len(hidden_states), expected_num_stages)
+
+ # MobileViT's feature maps are of shape (batch_size, num_channels, height, width)
+ # with the width and height being successively divided by 2.
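+ # e.g. with image_size=32 the five stages yield spatial sizes 16, 8, 4, 2 and 1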
+ divisor = 2
+ for i in range(len(hidden_states)):
+ self.assertListEqual(
+ list(hidden_states[i].shape[-2:]),
+ [self.model_tester.image_size // divisor, self.model_tester.image_size // divisor],
+ )
+ divisor *= 2
+
+ self.assertEqual(self.model_tester.output_stride, divisor // 2)
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ def test_for_semantic_segmentation(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_semantic_segmentation(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = MobileViTModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_torch
+@require_vision
+class MobileViTModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return MobileViTFeatureExtractor.from_pretrained("apple/mobilevit-xx-small") if is_vision_available() else None
+
+ @slow
+ def test_inference_image_classification_head(self):
+ model = MobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small").to(torch_device)
+
+ feature_extractor = self.default_feature_extractor
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
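+ # the checkpoint classifies over the 1000 ImageNet-1k classes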
+ expected_shape = torch.Size((1, 1000))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor([-1.9364, -1.2327, -0.4653]).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
+
+ @slow
+ def test_inference_semantic_segmentation(self):
+ model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
+ model = model.to(torch_device)
+
+ feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
+
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+ logits = outputs.logits
+
+ # verify the logits
+ expected_shape = torch.Size((1, 21, 32, 32))
+ self.assertEqual(logits.shape, expected_shape)
+
+ expected_slice = torch.tensor(
+ [
+ [[6.9713, 6.9786, 7.2422], [7.2893, 7.2825, 7.4446], [7.6580, 7.8797, 7.9420]],
+ [[-10.6869, -10.3250, -10.3471], [-10.4228, -9.9868, -9.7132], [-11.0405, -11.0221, -10.7318]],
+ [[-3.3089, -2.8539, -2.6740], [-3.2706, -2.5621, -2.5108], [-3.2534, -2.6615, -2.6651]],
+ ],
+ device=torch_device,
+ )
+
+ self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
diff --git a/tests/models/mt5/test_modeling_tf_mt5.py b/tests/models/mt5/test_modeling_tf_mt5.py
index 1ab1a635b396..5cbf3afb599b 100644
--- a/tests/models/mt5/test_modeling_tf_mt5.py
+++ b/tests/models/mt5/test_modeling_tf_mt5.py
@@ -67,7 +67,7 @@ def test_small_integration_test(self):
labels = tokenizer("Hi I am", return_tensors="tf").input_ids
loss = model(input_ids, labels=labels).loss
- mtf_score = -tf.math.reduce_sum(loss).numpy()
+ mtf_score = -tf.math.reduce_mean(loss).numpy()
- EXPECTED_SCORE = -84.9127
+ EXPECTED_SCORE = -21.228168
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 2e-4)
diff --git a/tests/models/mvp/__init__.py b/tests/models/mvp/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/mvp/test_modeling_mvp.py b/tests/models/mvp/test_modeling_mvp.py
new file mode 100644
index 000000000000..e0247d4233e8
--- /dev/null
+++ b/tests/models/mvp/test_modeling_mvp.py
@@ -0,0 +1,787 @@
+# coding=utf-8
+# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch MVP model. """
+
+
+import copy
+import tempfile
+import unittest
+
+import timeout_decorator # noqa
+
+from transformers import MvpConfig, is_torch_available
+from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
+from transformers.utils import cached_property
+
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ MvpForCausalLM,
+ MvpForConditionalGeneration,
+ MvpForQuestionAnswering,
+ MvpForSequenceClassification,
+ MvpModel,
+ MvpTokenizer,
+ )
+ from transformers.models.mvp.modeling_mvp import MvpDecoder, MvpEncoder, shift_tokens_right
+
+
+def prepare_mvp_inputs_dict(
+ config,
+ input_ids,
+ decoder_input_ids=None,
+ attention_mask=None,
+ decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+):
+ if attention_mask is None:
+ attention_mask = input_ids.ne(config.pad_token_id)
+ if decoder_attention_mask is None:
+ decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
+ if head_mask is None:
+ head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
+ if decoder_head_mask is None:
+ decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
+ if cross_attn_head_mask is None:
+ cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
+ return {
+ "input_ids": input_ids,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "decoder_attention_mask": attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ }
+
+
+class MvpModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_labels=False,
+ vocab_size=99,
+ hidden_size=16,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=20,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+
+ def prepare_config_and_inputs(self):
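+ # clamp the random ids to a minimum of 3 so they never collide with the special tokens (bos=0, pad=1, eos=2)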
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(3)
+ input_ids[:, -1] = self.eos_token_id # Eos Token
+
+ decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ config = self.get_config()
+ inputs_dict = prepare_mvp_inputs_dict(config, input_ids, decoder_input_ids)
+ return config, inputs_dict
+
+ def get_config(self):
+ return MvpConfig(
+ vocab_size=self.vocab_size,
+ d_model=self.hidden_size,
+ encoder_layers=self.num_hidden_layers,
+ decoder_layers=self.num_hidden_layers,
+ encoder_attention_heads=self.num_attention_heads,
+ decoder_attention_heads=self.num_attention_heads,
+ encoder_ffn_dim=self.intermediate_size,
+ decoder_ffn_dim=self.intermediate_size,
+ dropout=self.hidden_dropout_prob,
+ attention_dropout=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.bos_token_id,
+ pad_token_id=self.pad_token_id,
+ )
+
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.max_position_embeddings = 100
+ config.vocab_size = 300
+ return config
+
+ def prepare_config_and_inputs_for_common(self):
+ config, inputs_dict = self.prepare_config_and_inputs()
+ return config, inputs_dict
+
+ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
+ model = MvpModel(config=config).get_decoder().to(torch_device).eval()
+ input_ids = inputs_dict["input_ids"]
+ attention_mask = inputs_dict["attention_mask"]
+ head_mask = inputs_dict["head_mask"]
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
+
+ output, past_key_values = outputs.to_tuple()
+
+ # create hypothetical multiple next tokens and extend to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_attn_mask = ids_tensor((self.batch_size, 3), 2)
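+ # drawing from a "vocab" of size 2 gives a random 0/1 mask for the three appended positions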
+
+ # append the new tokens to input_ids and the new mask to attention_mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)
+
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
+ "last_hidden_state"
+ ]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def check_encoder_decoder_model_standalone(self, config, inputs_dict):
+ model = MvpModel(config=config).to(torch_device).eval()
+ outputs = model(**inputs_dict)
+
+ encoder_last_hidden_state = outputs.encoder_last_hidden_state
+ last_hidden_state = outputs.last_hidden_state
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ encoder = model.get_encoder()
+ encoder.save_pretrained(tmpdirname)
+ encoder = MvpEncoder.from_pretrained(tmpdirname).to(torch_device)
+
+ encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
+ 0
+ ]
+
+ self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ decoder = model.get_decoder()
+ decoder.save_pretrained(tmpdirname)
+ decoder = MvpDecoder.from_pretrained(tmpdirname).to(torch_device)
+
+ last_hidden_state_2 = decoder(
+ input_ids=inputs_dict["decoder_input_ids"],
+ attention_mask=inputs_dict["decoder_attention_mask"],
+ encoder_hidden_states=encoder_last_hidden_state,
+ encoder_attention_mask=inputs_dict["attention_mask"],
+ )[0]
+
+ self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
+
+
+@require_torch
+class MvpHeadTests(unittest.TestCase):
+ vocab_size = 99
+
+ def _get_config_and_data(self):
+ input_ids = torch.tensor(
+ [
+ [71, 82, 18, 33, 46, 91, 2],
+ [68, 34, 26, 58, 30, 82, 2],
+ [5, 97, 17, 39, 94, 40, 2],
+ [76, 83, 94, 25, 70, 78, 2],
+ [87, 59, 41, 35, 48, 66, 2],
+ [55, 13, 16, 58, 5, 2, 1], # note padding
+ [64, 27, 31, 51, 12, 75, 2],
+ [52, 64, 86, 17, 83, 39, 2],
+ [48, 61, 9, 24, 71, 82, 2],
+ [26, 1, 60, 48, 22, 13, 2],
+ [21, 5, 62, 28, 14, 76, 2],
+ [45, 98, 37, 86, 59, 48, 2],
+ [70, 70, 50, 9, 28, 0, 2],
+ ],
+ dtype=torch.long,
+ device=torch_device,
+ )
+
+ batch_size = input_ids.shape[0]
+ config = MvpConfig(
+ vocab_size=self.vocab_size,
+ d_model=24,
+ encoder_layers=2,
+ decoder_layers=2,
+ encoder_attention_heads=2,
+ decoder_attention_heads=2,
+ encoder_ffn_dim=32,
+ decoder_ffn_dim=32,
+ max_position_embeddings=48,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ )
+ return config, input_ids, batch_size
+
+ def test_sequence_classification_forward(self):
+ config, input_ids, batch_size = self._get_config_and_data()
+ labels = _long_tensor([2] * batch_size).to(torch_device)
+ config.num_labels = 3
+ model = MvpForSequenceClassification(config)
+ model.to(torch_device)
+ outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
+ expected_shape = torch.Size((batch_size, config.num_labels))
+ self.assertEqual(outputs["logits"].shape, expected_shape)
+ self.assertIsInstance(outputs["loss"].item(), float)
+
+ def test_question_answering_forward(self):
+ config, input_ids, batch_size = self._get_config_and_data()
+ sequence_labels = ids_tensor([batch_size], 2).to(torch_device)
+ model = MvpForQuestionAnswering(config)
+ model.to(torch_device)
+ outputs = model(
+ input_ids=input_ids,
+ start_positions=sequence_labels,
+ end_positions=sequence_labels,
+ )
+
+ self.assertEqual(outputs["start_logits"].shape, input_ids.shape)
+ self.assertEqual(outputs["end_logits"].shape, input_ids.shape)
+ self.assertIsInstance(outputs["loss"].item(), float)
+
+ @timeout_decorator.timeout(1)
+ def test_lm_forward(self):
+ config, input_ids, batch_size = self._get_config_and_data()
+ lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
+ lm_model = MvpForConditionalGeneration(config)
+ lm_model.to(torch_device)
+ outputs = lm_model(input_ids=input_ids, labels=lm_labels)
+ expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
+ self.assertEqual(outputs["logits"].shape, expected_shape)
+ self.assertIsInstance(outputs["loss"].item(), float)
+
+ def test_lm_uneven_forward(self):
+ config = MvpConfig(
+ vocab_size=self.vocab_size,
+ d_model=14,
+ encoder_layers=2,
+ decoder_layers=2,
+ encoder_attention_heads=2,
+ decoder_attention_heads=2,
+ encoder_ffn_dim=8,
+ decoder_ffn_dim=8,
+ max_position_embeddings=48,
+ )
+ lm_model = MvpForConditionalGeneration(config).to(torch_device)
+ context = torch.tensor(
+ [[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], device=torch_device, dtype=torch.long
+ )
+ summary = torch.tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], device=torch_device, dtype=torch.long)
+ outputs = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
+ expected_shape = (*summary.shape, config.vocab_size)
+ self.assertEqual(outputs["logits"].shape, expected_shape)
+
+ def test_generate_beam_search(self):
+ input_ids = torch.tensor([[71, 82, 2], [68, 34, 2]], device=torch_device, dtype=torch.long)
+ config = MvpConfig(
+ vocab_size=self.vocab_size,
+ d_model=24,
+ encoder_layers=2,
+ decoder_layers=2,
+ encoder_attention_heads=2,
+ decoder_attention_heads=2,
+ encoder_ffn_dim=32,
+ decoder_ffn_dim=32,
+ max_position_embeddings=48,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ )
+ lm_model = MvpForConditionalGeneration(config).to(torch_device)
+ lm_model.eval()
+
+ max_length = 5
+ generated_ids = lm_model.generate(
+ input_ids.clone(),
+ do_sample=True,
+ num_return_sequences=1,
+ num_beams=2,
+ no_repeat_ngram_size=3,
+ max_length=max_length,
+ )
+ self.assertEqual(generated_ids.shape, (input_ids.shape[0], max_length))
+
+ def test_shift_tokens_right(self):
+ input_ids = torch.tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]], dtype=torch.long)
+ shifted = shift_tokens_right(input_ids, 1, 2)
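+ # shifting prepends the decoder start token (2) and drops the final position, so here exactly one pad token disappears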
+ n_pad_before = input_ids.eq(1).float().sum()
+ n_pad_after = shifted.eq(1).float().sum()
+ self.assertEqual(shifted.shape, input_ids.shape)
+ self.assertEqual(n_pad_after, n_pad_before - 1)
+ self.assertTrue(torch.eq(shifted[:, 0], 2).all())
+
+ @slow
+ def test_tokenization(self):
+ tokenizer = MvpTokenizer.from_pretrained("RUCAIBox/mvp")
+ examples = [" Hello world", " DomDramg"] # need leading spaces for equality
+ fairseq_results = [
+ torch.tensor([0, 20920, 232, 2]),
+ torch.tensor([0, 11349, 495, 4040, 571, 2]),
+ ]
+ for ex, desired_result in zip(examples, fairseq_results):
+ mvp_toks = tokenizer.encode(ex, return_tensors="pt").squeeze()
+ assert_tensors_close(desired_result.long(), mvp_toks, prefix=ex)
+
+ def test_generate_fp16(self):
+ config, input_ids, batch_size = self._get_config_and_data()
+ attention_mask = input_ids.ne(1).to(torch_device)
+ model = MvpForConditionalGeneration(config).eval().to(torch_device)
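+ # only cast to fp16 on GPU; half-precision inference is generally not supported on CPU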
+ if torch_device == "cuda":
+ model.half()
+ model.generate(input_ids, attention_mask=attention_mask)
+ model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
+
+ def test_dummy_inputs(self):
+ config, *_ = self._get_config_and_data()
+ model = MvpForConditionalGeneration(config).eval().to(torch_device)
+ model(**model.dummy_inputs)
+
+ def test_resize_tokens_embeddings_more(self):
+ config, input_ids, _ = self._get_config_and_data()
+
+ def _get_embs(m):
+ return (m.get_input_embeddings().weight.data.clone(), m.get_output_embeddings().weight.data.clone())
+
+ model = MvpForConditionalGeneration(config).eval().to(torch_device)
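+ # input and output embeddings are expected to stay tied (identical) before and after resizing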
+ input, output = _get_embs(model)
+ self.assertTrue(torch.eq(input, output).all())
+ new_vocab_size = 45
+ model.resize_token_embeddings(new_vocab_size)
+ input_new, output_new = _get_embs(model)
+ self.assertEqual(input_new.shape, (new_vocab_size, config.d_model))
+ self.assertEqual(output_new.shape, (new_vocab_size, config.d_model))
+ self.assertTrue(torch.eq(input_new, output_new).all())
+
+
+@require_torch
+class MvpModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ all_model_classes = (
+ (MvpModel, MvpForConditionalGeneration, MvpForSequenceClassification, MvpForQuestionAnswering)
+ if is_torch_available()
+ else ()
+ )
+ all_generative_model_classes = (MvpForConditionalGeneration,) if is_torch_available() else ()
+ is_encoder_decoder = True
+ fx_compatible = False
+ test_pruning = False
+ test_missing_keys = False
+
+ def setUp(self):
+ self.model_tester = MvpModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=MvpConfig)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_save_load_strict(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
+ self.assertEqual(info["missing_keys"], [])
+
+ def test_decoder_model_past_with_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
+ def test_encoder_decoder_model_standalone(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
+
+ # MvpForSequenceClassification does not support inputs_embeds
+ def test_inputs_embeds(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in (MvpModel, MvpForConditionalGeneration, MvpForQuestionAnswering):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+ if not self.is_encoder_decoder:
+ input_ids = inputs["input_ids"]
+ del inputs["input_ids"]
+ else:
+ encoder_input_ids = inputs["input_ids"]
+ decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+ del inputs["input_ids"]
+ inputs.pop("decoder_input_ids", None)
+
+ wte = model.get_input_embeddings()
+ if not self.is_encoder_decoder:
+ inputs["inputs_embeds"] = wte(input_ids)
+ else:
+ inputs["inputs_embeds"] = wte(encoder_input_ids)
+ inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
+
+ with torch.no_grad():
+ model(**inputs)[0]
+
+ def test_generate_fp16(self):
+ config, input_dict = self.model_tester.prepare_config_and_inputs()
+ input_ids = input_dict["input_ids"]
+ attention_mask = input_ids.ne(1).to(torch_device)
+ model = MvpForConditionalGeneration(config).eval().to(torch_device)
+ if torch_device == "cuda":
+ model.half()
+ model.generate(input_ids, attention_mask=attention_mask)
+ model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
+
+
+def assert_tensors_close(a, b, atol=1e-12, prefix=""):
+ """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
+ if a is None and b is None:
+ return True
+ try:
+ if torch.allclose(a, b, atol=atol):
+ return True
+ raise
+ except Exception:
+ pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item()
+ if a.numel() > 100:
+ msg = f"tensor values are {pct_different:.1%} percent different."
+ else:
+ msg = f"{a} != {b}"
+ if prefix:
+ msg = prefix + ": " + msg
+ raise AssertionError(msg)
+
+
+def _long_tensor(tok_lst):
+ return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)
+
+
+@require_torch
+@require_sentencepiece
+@require_tokenizers
+class MvpModelIntegrationTests(unittest.TestCase):
+ @cached_property
+ def default_tokenizer(self):
+ return MvpTokenizer.from_pretrained("RUCAIBox/mvp")
+
+ @slow
+ def test_inference_no_head(self):
+ model = MvpModel.from_pretrained("RUCAIBox/mvp").to(torch_device)
+ input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+ attention_mask = input_ids.ne(model.config.pad_token_id)
+ with torch.no_grad():
+ output = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
+ expected_shape = torch.Size((1, 11, 1024))
+ self.assertEqual(output.shape, expected_shape)
+ expected_slice = torch.tensor(
+ [[0.3461, 0.3624, 0.2689], [0.3461, 0.3624, 0.2689], [-0.1562, 1.1637, -0.3784]], device=torch_device
+ )
+ self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3))
+
+ @slow
+ def test_summarization_inference(self):
+ model = MvpForConditionalGeneration.from_pretrained("RUCAIBox/mvp").to(torch_device)
+ tok = self.default_tokenizer
+ # fmt: off
+ PGE_ARTICLE = """ Listen to local radio broadcasts for advertisements that reference casinos in your area.\nIf none are in your area, listen to national radio broadcasts for advertisements of casinos in other areas.\nNote the location that is mentioned in each advertisement that involves a casino.\nIf no locations are mentioned, note any additional contact information, such as a website or phone number. Use that information to find out where the casinos are.;\n,\n\nIf you learn about more than 1 casino on the radio, use the Internet to search the distance between your location and each casino. Sites such as maps.google.com or mapquest.com will help you in this search.'"""
+ # fmt: on
+ EXPECTED_SUMMARY = "Listen to the radio.\nUse the Internet."
+ dct = tok.batch_encode_plus(
+ [PGE_ARTICLE],
+ return_tensors="pt",
+ ).to(torch_device)
+
+ hypotheses_batch = model.generate(**dct)
+
+ decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True)
+ self.assertEqual(EXPECTED_SUMMARY, decoded[0])
+
+
+class MvpStandaloneDecoderModelTester:
+ def __init__(
+ self,
+ parent,
+ vocab_size=99,
+ batch_size=13,
+ d_model=16,
+ decoder_seq_length=7,
+ is_training=True,
+ is_decoder=True,
+ use_attention_mask=True,
+ use_cache=False,
+ use_labels=True,
+ decoder_start_token_id=2,
+ decoder_ffn_dim=32,
+ decoder_layers=4,
+ encoder_attention_heads=4,
+ decoder_attention_heads=4,
+ max_position_embeddings=30,
+ is_encoder_decoder=False,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.decoder_seq_length = decoder_seq_length
+ # For common tests
+ self.seq_length = self.decoder_seq_length
+ self.is_training = is_training
+ self.use_attention_mask = use_attention_mask
+ self.use_labels = use_labels
+
+ self.vocab_size = vocab_size
+ self.d_model = d_model
+ self.hidden_size = d_model
+ self.num_hidden_layers = decoder_layers
+ self.decoder_layers = decoder_layers
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.encoder_attention_heads = encoder_attention_heads
+ self.decoder_attention_heads = decoder_attention_heads
+ self.num_attention_heads = decoder_attention_heads
+ self.eos_token_id = eos_token_id
+ self.bos_token_id = bos_token_id
+ self.pad_token_id = pad_token_id
+ self.decoder_start_token_id = decoder_start_token_id
+ self.use_cache = use_cache
+ self.max_position_embeddings = max_position_embeddings
+ self.is_encoder_decoder = is_encoder_decoder
+
+ self.scope = None
+ self.decoder_key_length = decoder_seq_length
+ self.base_model_out_len = 2
+ self.decoder_attention_idx = 1
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
+
+ attention_mask = None
+ if self.use_attention_mask:
+ attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
+
+ lm_labels = None
+ if self.use_labels:
+ lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
+
+ config = MvpConfig(
+ vocab_size=self.vocab_size,
+ d_model=self.d_model,
+ encoder_layers=self.decoder_layers,
+ decoder_layers=self.decoder_layers,
+ decoder_ffn_dim=self.decoder_ffn_dim,
+ encoder_attention_heads=self.encoder_attention_heads,
+ decoder_attention_heads=self.decoder_attention_heads,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.bos_token_id,
+ use_cache=self.use_cache,
+ pad_token_id=self.pad_token_id,
+ decoder_start_token_id=self.decoder_start_token_id,
+ max_position_embeddings=self.max_position_embeddings,
+ is_encoder_decoder=self.is_encoder_decoder,
+ )
+
+ return (
+ config,
+ input_ids,
+ attention_mask,
+ lm_labels,
+ )
+
+ def prepare_config_and_inputs_for_decoder(self):
+ (
+ config,
+ input_ids,
+ attention_mask,
+ lm_labels,
+ ) = self.prepare_config_and_inputs()
+
+ encoder_hidden_states = floats_tensor([self.batch_size, self.decoder_seq_length, self.hidden_size])
+ encoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
+
+ return (
+ config,
+ input_ids,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ lm_labels,
+ )
+
+ def create_and_check_decoder_model_past(
+ self,
+ config,
+ input_ids,
+ attention_mask,
+ lm_labels,
+ ):
+ config.use_cache = True
+ model = MvpDecoder(config=config).to(torch_device).eval()
+ # first forward pass
+ outputs = model(input_ids, use_cache=True)
+ outputs_use_cache_conf = model(input_ids)
+ outputs_no_past = model(input_ids, use_cache=False)
+
+ self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+ self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
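+ # the extra entry in the cached outputs is "past_key_values", which is absent when use_cache=False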
+
+ past_key_values = outputs["past_key_values"]
+
+ # create a hypothetical next token and extend next_input_ids with it
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # append the new token to next_input_ids
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+
+ output_from_no_past = model(next_input_ids)["last_hidden_state"]
+ output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
+
+ def create_and_check_decoder_model_attention_mask_past(
+ self,
+ config,
+ input_ids,
+ attention_mask,
+ lm_labels,
+ ):
+ model = MvpDecoder(config=config).to(torch_device).eval()
+
+ # create attention mask
+ attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+
+ half_seq_length = input_ids.shape[-1] // 2
+ attn_mask[:, half_seq_length:] = 0
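+ # the second half of the sequence is masked out, so those positions must not influence the outputs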
+
+ # first forward pass
+ past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
+
+ # create a hypothetical next token and extend next_input_ids with it
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # change a random masked slice from input_ids
+ random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
+ random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
+ input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
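+ # the changed position lies in the masked-out half, so the cached and non-cached runs should still agree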
+
+ # append to next input_ids and attn_mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ attn_mask = torch.cat(
+ [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
+ dim=1,
+ )
+
+ # get two different outputs
+ output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
+ "last_hidden_state"
+ ]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ attention_mask,
+ lm_labels,
+ ) = config_and_inputs
+
+ inputs_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class MvpStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ all_model_classes = (MvpDecoder, MvpForCausalLM) if is_torch_available() else ()
+ all_generative_model_classes = (MvpForCausalLM,) if is_torch_available() else ()
+ fx_compatible = True
+ test_pruning = False
+ is_encoder_decoder = False
+
+ def setUp(
+ self,
+ ):
+ self.model_tester = MvpStandaloneDecoderModelTester(self, is_training=False)
+ self.config_tester = ConfigTester(self, config_class=MvpConfig)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_decoder_model_past(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
+
+ def test_decoder_model_attn_mask_past(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
+
+ def test_retain_grad_hidden_states_attentions(self):
+ # decoder cannot keep gradients
+ return
diff --git a/tests/models/mvp/test_tokenization_mvp.py b/tests/models/mvp/test_tokenization_mvp.py
new file mode 100644
index 000000000000..71e83fba0e16
--- /dev/null
+++ b/tests/models/mvp/test_tokenization_mvp.py
@@ -0,0 +1,182 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+import unittest
+
+from transformers import BatchEncoding, MvpTokenizer, MvpTokenizerFast
+from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES
+from transformers.testing_utils import require_tokenizers, require_torch
+from transformers.utils import cached_property
+
+from ...test_tokenization_common import TokenizerTesterMixin, filter_roberta_detectors
+
+
+@require_tokenizers
+class TestTokenizationMvp(TokenizerTesterMixin, unittest.TestCase):
+ tokenizer_class = MvpTokenizer
+ rust_tokenizer_class = MvpTokenizerFast
+ test_rust_tokenizer = True
+ from_pretrained_filter = filter_roberta_detectors
+ # from_pretrained_kwargs = {'add_prefix_space': True}
+
+ def setUp(self):
+ super().setUp()
+ vocab = [
+ "l",
+ "o",
+ "w",
+ "e",
+ "r",
+ "s",
+ "t",
+ "i",
+ "d",
+ "n",
+ "\u0120",
+ "\u0120l",
+ "\u0120n",
+ "\u0120lo",
+ "\u0120low",
+ "er",
+ "\u0120lowest",
+ "\u0120newer",
+ "\u0120wider",
+ "",
+ ]
+ vocab_tokens = dict(zip(vocab, range(len(vocab))))
+ merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
+ self.special_tokens_map = {"unk_token": ""}
+
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+ self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
+ with open(self.vocab_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(vocab_tokens) + "\n")
+ with open(self.merges_file, "w", encoding="utf-8") as fp:
+ fp.write("\n".join(merges))
+
+ def get_tokenizer(self, **kwargs):
+ kwargs.update(self.special_tokens_map)
+ return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_rust_tokenizer(self, **kwargs):
+ kwargs.update(self.special_tokens_map)
+ return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_input_output_texts(self, tokenizer):
+ return "lower newer", "lower newer"
+
+ @cached_property
+ def default_tokenizer(self):
+ return MvpTokenizer.from_pretrained("RUCAIBox/mvp")
+
+ @cached_property
+ def default_tokenizer_fast(self):
+ return MvpTokenizerFast.from_pretrained("RUCAIBox/mvp")
+
+ @require_torch
+ def test_prepare_batch(self):
+ src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
+ expected_src_tokens = [0, 250, 251, 17818, 13, 39186, 1938, 4, 2]
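+ # 0 and 2 are the <s> and </s> ids of the RoBERTa-style vocabulary used by the MVP tokenizer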
+
+ for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
+ batch = tokenizer(src_text, max_length=len(expected_src_tokens), padding=True, return_tensors="pt")
+ self.assertIsInstance(batch, BatchEncoding)
+
+ self.assertEqual((2, 9), batch.input_ids.shape)
+ self.assertEqual((2, 9), batch.attention_mask.shape)
+ result = batch.input_ids.tolist()[0]
+ self.assertListEqual(expected_src_tokens, result)
+
+ @require_torch
+ def test_prepare_batch_empty_target_text(self):
+ src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
+ for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
+ batch = tokenizer(src_text, padding=True, return_tensors="pt")
+ # check if input_ids are returned and no labels
+ self.assertIn("input_ids", batch)
+ self.assertIn("attention_mask", batch)
+ self.assertNotIn("labels", batch)
+ self.assertNotIn("decoder_attention_mask", batch)
+
+ @require_torch
+ def test_tokenizer_as_target_length(self):
+ tgt_text = [
+ "Summary of the text.",
+ "Another summary.",
+ ]
+ for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
+ targets = tokenizer(text_target=tgt_text, max_length=32, padding="max_length", return_tensors="pt")
+ self.assertEqual(32, targets["input_ids"].shape[1])
+
+ @require_torch
+ def test_prepare_batch_not_longer_than_maxlen(self):
+ for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
+ batch = tokenizer(
+ ["I am a small frog" * 1024, "I am a small frog"], padding=True, truncation=True, return_tensors="pt"
+ )
+ self.assertIsInstance(batch, BatchEncoding)
+ self.assertEqual(batch.input_ids.shape, (2, 1024))
+
+ @require_torch
+ def test_special_tokens(self):
+
+ src_text = ["A long paragraph for summarization."]
+ tgt_text = [
+ "Summary of the text.",
+ ]
+ for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
+ inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
+ input_ids = inputs["input_ids"]
+ labels = inputs["labels"]
+ self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
+ self.assertTrue((labels[:, 0] == tokenizer.bos_token_id).all().item())
+ self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
+ self.assertTrue((labels[:, -1] == tokenizer.eos_token_id).all().item())
+
+ def test_pretokenized_inputs(self):
+ pass
+
+ def test_embeded_special_tokens(self):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ sentence = "A, AllenNLP sentence."
+ tokens_r = tokenizer_r.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
+ tokens_p = tokenizer_p.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
+
+ # token_type_ids should put 0 everywhere
+ self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"]))
+
+ # attention_mask should be all 1s, so its average over the sequence length should be 1 for both tokenizers
+ self.assertEqual(
+ sum(tokens_r["attention_mask"]) / len(tokens_r["attention_mask"]),
+ sum(tokens_p["attention_mask"]) / len(tokens_p["attention_mask"]),
+ )
+
+ tokens_r_str = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"])
+ tokens_p_str = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"])
+
+ # Rust correctly handles the space before the mask while Python doesn't
+ self.assertSequenceEqual(tokens_p["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
+ self.assertSequenceEqual(tokens_r["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
+
+ self.assertSequenceEqual(
+ tokens_p_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
+ )
+ self.assertSequenceEqual(
+ tokens_r_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
+ )
diff --git a/tests/models/nezha/__init__.py b/tests/models/nezha/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/nezha/test_modeling_nezha.py b/tests/models/nezha/test_modeling_nezha.py
new file mode 100644
index 000000000000..1083ed0796ee
--- /dev/null
+++ b/tests/models/nezha/test_modeling_nezha.py
@@ -0,0 +1,479 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import tempfile
+import unittest
+
+from transformers import NezhaConfig, is_torch_available
+from transformers.models.auto import get_values
+from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
+
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ MODEL_FOR_PRETRAINING_MAPPING,
+ NezhaForMaskedLM,
+ NezhaForMultipleChoice,
+ NezhaForNextSentencePrediction,
+ NezhaForPreTraining,
+ NezhaForQuestionAnswering,
+ NezhaForSequenceClassification,
+ NezhaForTokenClassification,
+ NezhaModel,
+ )
+ from transformers.models.nezha.modeling_nezha import NEZHA_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+class NezhaModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_token_type_ids=True,
+ use_labels=True,
+ vocab_size=99,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=128,
+ max_relative_position=32,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ num_labels=3,
+ num_choices=4,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_token_type_ids = use_token_type_ids
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.scope = scope
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ token_type_ids = None
+ if self.use_token_type_ids:
+ token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+ sequence_labels = None
+ token_labels = None
+ choice_labels = None
+ if self.use_labels:
+ sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+ token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+ choice_labels = ids_tensor([self.batch_size], self.num_choices)
+
+ config = self.get_config()
+
+ return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+ def get_config(self):
+ """
+ Returns a tiny configuration by default.
+ """
+ return NezhaConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ is_decoder=False,
+ initializer_range=self.initializer_range,
+ )
+
+ def prepare_config_and_inputs_for_decoder(self):
+ (
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ) = self.prepare_config_and_inputs()
+
+ config.is_decoder = True
+ encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
+ encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+
+ return (
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+
+ def create_and_check_model(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ model = NezhaModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
+ result = model(input_ids, token_type_ids=token_type_ids)
+ result = model(input_ids)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def create_and_check_model_as_decoder(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ ):
+ config.add_cross_attention = True
+ model = NezhaModel(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def create_and_check_for_masked_lm(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ model = NezhaForMaskedLM(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+ def create_and_check_for_next_sequence_prediction(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ model = NezhaForNextSentencePrediction(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ labels=sequence_labels,
+ )
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))
+
+ def create_and_check_for_pretraining(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ model = NezhaForPreTraining(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ labels=token_labels,
+ next_sentence_label=sequence_labels,
+ )
+ self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+ self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))
+
+ def create_and_check_for_question_answering(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ model = NezhaForQuestionAnswering(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ start_positions=sequence_labels,
+ end_positions=sequence_labels,
+ )
+ self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+ self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
+
+ def create_and_check_for_sequence_classification(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ config.num_labels = self.num_labels
+ model = NezhaForSequenceClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def create_and_check_for_token_classification(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ config.num_labels = self.num_labels
+ model = NezhaForTokenClassification(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
+
+ def create_and_check_for_multiple_choice(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ config.num_choices = self.num_choices
+ model = NezhaForMultipleChoice(config=config)
+ model.to(torch_device)
+ model.eval()
+ multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
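+ # replicate every example num_choices times, which is the layout expected by the multiple-choice head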
+ result = model(
+ multiple_choice_inputs_ids,
+ attention_mask=multiple_choice_input_mask,
+ token_type_ids=multiple_choice_token_type_ids,
+ labels=choice_labels,
+ )
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ) = config_and_inputs
+ inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class NezhaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+
+ all_model_classes = (
+ (
+ NezhaModel,
+ NezhaForMaskedLM,
+ NezhaForMultipleChoice,
+ NezhaForNextSentencePrediction,
+ NezhaForPreTraining,
+ NezhaForQuestionAnswering,
+ NezhaForSequenceClassification,
+ NezhaForTokenClassification,
+ )
+ if is_torch_available()
+ else ()
+ )
+ fx_compatible = True
+
+ # special case for ForPreTraining model
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+ if return_labels:
+ if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
+ inputs_dict["labels"] = torch.zeros(
+ (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
+ )
+ inputs_dict["next_sentence_label"] = torch.zeros(
+ self.model_tester.batch_size, dtype=torch.long, device=torch_device
+ )
+ return inputs_dict
+
+ def setUp(self):
+ self.model_tester = NezhaModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=NezhaConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_as_decoder(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+ self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
+
+ def test_model_as_decoder_with_default_input_mask(self):
+ # This regression test was failing with PyTorch < 1.3
+ (
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ ) = self.model_tester.prepare_config_and_inputs_for_decoder()
+
+ input_mask = None
+
+ self.model_tester.create_and_check_model_as_decoder(
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+
+ def test_for_masked_lm(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
+
+ def test_for_multiple_choice(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
+
+ def test_for_next_sequence_prediction(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_next_sequence_prediction(*config_and_inputs)
+
+ def test_for_pretraining(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
+
+ def test_for_question_answering(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
+
+ def test_for_sequence_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+
+ def test_for_token_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in NEZHA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = NezhaModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+ @slow
+ @require_torch_gpu
+ def test_torchscript_device_change(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+
+ # NezhaForMultipleChoice behaves incorrectly in JIT environments.
+ if model_class == NezhaForMultipleChoice:
+ continue
+
+ config.torchscript = True
+ model = model_class(config=config)
+
+ inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ traced_model = torch.jit.trace(
+ model, (inputs_dict["input_ids"].to("cpu"), inputs_dict["attention_mask"].to("cpu"))
+ )
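+ # save the CPU-traced model, then reload it on the target device to exercise the device change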
+
+ with tempfile.TemporaryDirectory() as tmp:
+ torch.jit.save(traced_model, os.path.join(tmp, "bert.pt"))
+ loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device)
+ loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
+
+
+@require_torch
+class NezhaModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference_nezha_model(self):
+ model = NezhaModel.from_pretrained("sijunhe/nezha-cn-base")
+ input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
+ attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1]])
+ with torch.no_grad():
+ output = model(input_ids, attention_mask=attention_mask)[0]
+ expected_shape = torch.Size((1, 6, 768))
+ self.assertEqual(output.shape, expected_shape)
+ expected_slice = torch.tensor([[[0.0685, 0.2441, 0.1102], [0.0600, 0.1906, 0.1349], [0.0221, 0.0819, 0.0586]]])
+
+ self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
+
+ @slow
+ def test_inference_nezha_masked_lm(self):
+ model = NezhaForMaskedLM.from_pretrained("sijunhe/nezha-cn-base")
+ input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
+ attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
+ with torch.no_grad():
+ output = model(input_ids, attention_mask=attention_mask)[0]
+ expected_shape = torch.Size((1, 6, 21128))
+ self.assertEqual(output.shape, expected_shape)
+ expected_slice = torch.tensor(
+ [[-2.7939, -1.7902, -2.2189], [-2.8585, -1.8908, -2.3723], [-2.6499, -1.7750, -2.2558]]
+ )
+
+ self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
diff --git a/tests/models/nllb/__init__.py b/tests/models/nllb/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/nllb/test_tokenization_nllb.py b/tests/models/nllb/test_tokenization_nllb.py
new file mode 100644
index 000000000000..d77b101fa766
--- /dev/null
+++ b/tests/models/nllb/test_tokenization_nllb.py
@@ -0,0 +1,428 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+import tempfile
+import unittest
+
+from transformers import (
+ SPIECE_UNDERLINE,
+ AddedToken,
+ BatchEncoding,
+ NllbTokenizer,
+ NllbTokenizerFast,
+ is_torch_available,
+)
+from transformers.testing_utils import (
+ get_tests_dir,
+ nested_simplify,
+ require_sentencepiece,
+ require_tokenizers,
+ require_torch,
+)
+
+from ...test_tokenization_common import TokenizerTesterMixin
+
+
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
+
+
+if is_torch_available():
+ from transformers.models.m2m_100.modeling_m2m_100 import shift_tokens_right
+
+EN_CODE = 256047
+RO_CODE = 256145
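+# ids of the "eng_Latn" and "ron_Latn" language-code tokens in the NLLB vocabulary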
+
+
+@require_sentencepiece
+@require_tokenizers
+class NllbTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+ tokenizer_class = NllbTokenizer
+ rust_tokenizer_class = NllbTokenizerFast
+ test_rust_tokenizer = True
+ test_sentencepiece = True
+ from_pretrained_kwargs = {}
+
+ def setUp(self):
+ super().setUp()
+
+ # We have a SentencePiece fixture for testing
+ tokenizer = NllbTokenizer(SAMPLE_VOCAB, keep_accents=True)
+ tokenizer.save_pretrained(self.tmpdirname)
+
+ def test_full_tokenizer(self):
+ tokenizer = NllbTokenizer(SAMPLE_VOCAB, keep_accents=True)
+
+ tokens = tokenizer.tokenize("This is a test")
+ self.assertListEqual(tokens, ["āThis", "āis", "āa", "āt", "est"])
+
+ self.assertListEqual(
+ tokenizer.convert_tokens_to_ids(tokens),
+ [value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
+ )
+
+ tokens = tokenizer.tokenize("I was born in 92000, and this is falsƩ.")
+ self.assertListEqual(
+ tokens,
+ [
+ SPIECE_UNDERLINE + "I",
+ SPIECE_UNDERLINE + "was",
+ SPIECE_UNDERLINE + "b",
+ "or",
+ "n",
+ SPIECE_UNDERLINE + "in",
+ SPIECE_UNDERLINE + "",
+ "9",
+ "2",
+ "0",
+ "0",
+ "0",
+ ",",
+ SPIECE_UNDERLINE + "and",
+ SPIECE_UNDERLINE + "this",
+ SPIECE_UNDERLINE + "is",
+ SPIECE_UNDERLINE + "f",
+ "al",
+ "s",
+ "Ć©",
+ ".",
+ ],
+ )
+ ids = tokenizer.convert_tokens_to_ids(tokens)
+ self.assertListEqual(
+ ids,
+ [
+ value + tokenizer.fairseq_offset
+ for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
+ ],
+ )
+
+ back_tokens = tokenizer.convert_ids_to_tokens(ids)
+ self.assertListEqual(
+ back_tokens,
+ [
+ SPIECE_UNDERLINE + "I",
+ SPIECE_UNDERLINE + "was",
+ SPIECE_UNDERLINE + "b",
+ "or",
+ "n",
+ SPIECE_UNDERLINE + "in",
+ SPIECE_UNDERLINE + "",
+ "",
+ "2",
+ "0",
+ "0",
+ "0",
+ ",",
+ SPIECE_UNDERLINE + "and",
+ SPIECE_UNDERLINE + "this",
+ SPIECE_UNDERLINE + "is",
+ SPIECE_UNDERLINE + "f",
+ "al",
+ "s",
+ "",
+ ".",
+ ],
+ )
+
+ # overwrite from test_tokenization_common to speed up test
+ def test_save_pretrained(self):
+ self.tokenizers_list[0] = (self.rust_tokenizer_class, "hf-internal-testing/tiny-random-nllb", {})
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ tmpdirname2 = tempfile.mkdtemp()
+
+ tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2)
+ tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)
+
+ # Checks it saves with the same files + the tokenizer.json file for the fast one
+ self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files))
+ tokenizer_r_files = tuple(f for f in tokenizer_r_files if "tokenizer.json" not in f)
+ self.assertSequenceEqual(tokenizer_r_files, tokenizer_p_files)
+
+ # Checks everything loads correctly in the same way
+ tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2)
+ tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2)
+
+ # Check special tokens are set accordingly on Rust and Python
+ for key in tokenizer_pp.special_tokens_map:
+ self.assertTrue(hasattr(tokenizer_rp, key))
+
+ shutil.rmtree(tmpdirname2)
+
+ # Save tokenizer rust, legacy_format=True
+ tmpdirname2 = tempfile.mkdtemp()
+
+ tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2, legacy_format=True)
+ tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)
+
+ # Checks it saves with the same files
+ self.assertSequenceEqual(tokenizer_r_files, tokenizer_p_files)
+
+ # Checks everything loads correctly in the same way
+ tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2)
+ tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2)
+
+ # Check special tokens are set accordingly on Rust and Python
+ for key in tokenizer_pp.special_tokens_map:
+ self.assertTrue(hasattr(tokenizer_rp, key))
+
+ shutil.rmtree(tmpdirname2)
+
+ # Save tokenizer rust, legacy_format=False
+ tmpdirname2 = tempfile.mkdtemp()
+
+ tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2, legacy_format=False)
+ tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)
+
+ # Checks it saved the tokenizer.json file
+ self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files))
+
+ # Checks everything loads correctly in the same way
+ tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2)
+ tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2)
+
+ # Check special tokens are set accordingly on Rust and Python
+ for key in tokenizer_pp.special_tokens_map:
+ self.assertTrue(hasattr(tokenizer_rp, key))
+
+ shutil.rmtree(tmpdirname2)
+
+ @require_torch
+ def test_prepare_seq2seq_batch(self):
+ if not self.test_seq2seq:
+ return
+
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ # Longer text that will definitely require truncation.
+ src_text = [
+ " UN Chief Says There Is No Military Solution in Syria",
+ " Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for"
+ " Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons"
+ " will only worsen the violence and misery for millions of people.",
+ ]
+ tgt_text = [
+ "Åeful ONU declarÄ cÄ nu existÄ o soluÅ£ie militarÄ Ć®n Siria",
+ "Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al"
+ ' Rusiei pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi'
+ " cÄ noi arme nu vor face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£ele Åi mizeria pentru milioane de oameni.",
+ ]
+ try:
+ batch = tokenizer.prepare_seq2seq_batch(
+ src_texts=src_text,
+ tgt_texts=tgt_text,
+ max_length=3,
+ max_target_length=10,
+ return_tensors="pt",
+ src_lang="eng_Latn",
+ tgt_lang="ron_Latn",
+ )
+ except NotImplementedError:
+ return
+ self.assertEqual(batch.input_ids.shape[1], 3)
+ self.assertEqual(batch.labels.shape[1], 10)
+ # max_target_length will default to max_length if not specified
+ batch = tokenizer.prepare_seq2seq_batch(
+ src_text, tgt_texts=tgt_text, max_length=3, return_tensors="pt"
+ )
+ self.assertEqual(batch.input_ids.shape[1], 3)
+ self.assertEqual(batch.labels.shape[1], 3)
+
+ batch_encoder_only = tokenizer.prepare_seq2seq_batch(
+ src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt"
+ )
+ self.assertEqual(batch_encoder_only.input_ids.shape[1], 3)
+ self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3)
+ self.assertNotIn("decoder_input_ids", batch_encoder_only)
+
+ @unittest.skip("Unfortunately way too slow to build a BPE with SentencePiece.")
+ def test_save_slow_from_fast_and_reload_fast(self):
+ pass
+
+ def test_special_tokens_initialization(self):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+
+ added_tokens = [AddedToken("", lstrip=True)]
+
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(
+ pretrained_name, additional_special_tokens=added_tokens, **kwargs
+ )
+ r_output = tokenizer_r.encode("Hey this is a token")
+
+ special_token_id = tokenizer_r.encode("", add_special_tokens=False)[0]
+
+ self.assertTrue(special_token_id in r_output)
+
+ if self.test_slow_tokenizer:
+ tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
+ pretrained_name,
+ additional_special_tokens=added_tokens,
+ **kwargs, # , from_slow=True <- unfortunately too slow to convert
+ )
+ tokenizer_p = self.tokenizer_class.from_pretrained(
+ pretrained_name, additional_special_tokens=added_tokens, **kwargs
+ )
+
+ p_output = tokenizer_p.encode("Hey this is a token")
+
+ cr_output = tokenizer_cr.encode("Hey this is a token")
+
+ self.assertEqual(p_output, r_output)
+ self.assertEqual(cr_output, r_output)
+ self.assertTrue(special_token_id in p_output)
+ self.assertTrue(special_token_id in cr_output)
+
+
+@require_torch
+@require_sentencepiece
+@require_tokenizers
+class NllbDistilledIntegrationTest(unittest.TestCase):
+ checkpoint_name = "facebook/nllb-200-distilled-600M"
+ src_text = [
+ " UN Chief Says There Is No Military Solution in Syria",
+ """ Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""",
+ ]
+ tgt_text = [
+ "Åeful ONU declarÄ cÄ nu existÄ o soluÅ£ie militarÄ Ć®n Siria",
+ "Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei"
+ ' pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu vor'
+ " face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£ele Åi mizeria pentru milioane de oameni.",
+ ]
+ expected_src_tokens = [
+ 16297,
+ 134408,
+ 8165,
+ 248066,
+ 14734,
+ 950,
+ 1135,
+ 105721,
+ 3573,
+ 83,
+ 27352,
+ 108,
+ 49486,
+ 2,
+ 256047,
+ ]
+
+ @classmethod
+ def setUpClass(cls):
+ cls.tokenizer: NllbTokenizer = NllbTokenizer.from_pretrained(
+ cls.checkpoint_name, src_lang="eng_Latn", tgt_lang="ron_Latn"
+ )
+ cls.pad_token_id = 1
+ return cls
+
+ def test_language_codes(self):
+ self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ace_Arab"], 256001)
+ self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ace_Latn"], 256002)
+ self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["fra_Latn"], 256057)
+
+ def test_enro_tokenizer_batch_encode_plus(self):
+ ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
+ self.assertListEqual(self.expected_src_tokens, ids)
+
+ def test_enro_tokenizer_decode_ignores_language_codes(self):
+ self.assertIn(RO_CODE, self.tokenizer.all_special_ids)
+ # fmt: off
+ generated_ids = [RO_CODE, 4254, 98068, 112923, 39072, 3909, 713, 102767, 26, 17314, 35642, 14683, 33118, 2022, 66987, 2, 256047]
+ # fmt: on
+
+ result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+ expected_romanian = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
+ self.assertEqual(result, expected_romanian)
+ self.assertNotIn(self.tokenizer.eos_token, result)
+
+ def test_enro_tokenizer_truncation(self):
+ src_text = ["this is gunna be a long sentence " * 20]
+ assert isinstance(src_text[0], str)
+ desired_max_length = 10
+ ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
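+ # NLLB appends the EOS token followed by the source language code, so the last two ids are checked below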
+ self.assertEqual(ids[-2], 2)
+ self.assertEqual(ids[-1], EN_CODE)
+ self.assertEqual(len(ids), desired_max_length)
+
+ def test_mask_token(self):
+ self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["", "ar_AR"]), [256203, 3])
+
+ def test_special_tokens_unaffected_by_save_load(self):
+ tmpdirname = tempfile.mkdtemp()
+ original_special_tokens = self.tokenizer.fairseq_tokens_to_ids
+ self.tokenizer.save_pretrained(tmpdirname)
+ new_tok = NllbTokenizer.from_pretrained(tmpdirname)
+ self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
+
+ @require_torch
+ def test_enro_tokenizer_prepare_batch(self):
+ batch = self.tokenizer(
+ self.src_text,
+ text_target=self.tgt_text,
+ padding=True,
+ truncation=True,
+ max_length=len(self.expected_src_tokens),
+ return_tensors="pt",
+ )
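+ # build decoder inputs that start with the target language code, mirroring the NLLB decoding convention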
+ batch["decoder_input_ids"] = shift_tokens_right(
+ batch["labels"], self.tokenizer.pad_token_id, self.tokenizer.lang_code_to_id["ron_Latn"]
+ )
+
+ self.assertIsInstance(batch, BatchEncoding)
+
+ self.assertEqual((2, 15), batch.input_ids.shape)
+ self.assertEqual((2, 15), batch.attention_mask.shape)
+ result = batch.input_ids.tolist()[0]
+ self.assertListEqual(self.expected_src_tokens, result)
+ self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
+ # Test that special tokens are reset
+ self.assertEqual(self.tokenizer.prefix_tokens, [])
+ self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
+
+ def test_seq2seq_max_length(self):
+ batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
+ targets = self.tokenizer(
+ text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
+ )
+ labels = targets["input_ids"]
+ batch["decoder_input_ids"] = shift_tokens_right(
+ labels,
+ self.tokenizer.pad_token_id,
+ decoder_start_token_id=self.tokenizer.lang_code_to_id[self.tokenizer.tgt_lang],
+ )
+
+ self.assertEqual(batch.input_ids.shape[1], 3)
+ self.assertEqual(batch.decoder_input_ids.shape[1], 10)
+
+ @require_torch
+ def test_tokenizer_translation(self):
+ inputs = self.tokenizer._build_translation_inputs(
+ "A test", return_tensors="pt", src_lang="eng_Latn", tgt_lang="fra_Latn"
+ )
+
+ self.assertEqual(
+ nested_simplify(inputs),
+ {
+ # A, test, EOS, eng_Latn
+ "input_ids": [[70, 7356, 2, 256047]],
+ "attention_mask": [[1, 1, 1, 1]],
+ # fra_Latn
+ "forced_bos_token_id": 256057,
+ },
+ )
diff --git a/tests/models/opt/__init__.py b/tests/models/opt/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/opt/test_modeling_flax_opt.py b/tests/models/opt/test_modeling_flax_opt.py
new file mode 100644
index 000000000000..208ea0c0d7a4
--- /dev/null
+++ b/tests/models/opt/test_modeling_flax_opt.py
@@ -0,0 +1,406 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+import timeout_decorator # noqa
+
+from transformers import OPTConfig, is_flax_available
+from transformers.testing_utils import require_flax, require_sentencepiece, slow
+
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
+
+
+if is_flax_available():
+ import os
+
+ # The slow tests are often failing with OOM error on GPU
+ # This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed
+ # but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
+ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
+
+ import jax
+ import jax.numpy as jnp
+ from transformers import FlaxOPTForCausalLM, FlaxOPTModel, GPT2Tokenizer
+
+
+def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=None):
+ if attention_mask is None:
+ attention_mask = np.where(input_ids != config.pad_token_id, 1, 0)
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+
+
+@require_flax
+class FlaxOPTModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_labels=False,
+ vocab_size=99,
+ hidden_size=16,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=20,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ embed_dim=16,
+ word_embed_proj_dim=16,
+ initializer_range=0.02,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.embed_dim = embed_dim
+ self.word_embed_proj_dim = word_embed_proj_dim
+ self.initializer_range = initializer_range
+ self.is_encoder_decoder = False
+
+ def prepare_config_and_inputs(self):
+ input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size)
+ input_ids = np.concatenate((input_ids, 2 * np.ones((self.batch_size, 1), dtype=np.int64)), -1)
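+ # clip away ids that collide with special tokens and force the last token of every row to be EOS (id 2)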
+
+ config = OPTConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ ffn_dim=self.intermediate_size,
+ dropout=self.hidden_dropout_prob,
+ attention_dropout=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.bos_token_id,
+ pad_token_id=self.pad_token_id,
+ embed_dim=self.embed_dim,
+ is_encoder_decoder=False,
+ word_embed_proj_dim=self.word_embed_proj_dim,
+ initializer_range=self.initializer_range,
+ use_cache=False,
+ )
+ inputs_dict = prepare_opt_inputs_dict(config, input_ids)
+ return config, inputs_dict
+
+ def prepare_config_and_inputs_for_common(self):
+ config, inputs_dict = self.prepare_config_and_inputs()
+ return config, inputs_dict
+
+ def check_use_cache_forward(self, model_class_name, config, inputs_dict):
+ max_length = 20
+ model = model_class_name(config)
+
+ input_ids = inputs_dict["input_ids"]
+ attention_mask = inputs_dict["attention_mask"]
+
+ past_key_values = model.init_cache(input_ids.shape[0], max_length)
+ attention_mask = jnp.ones((input_ids.shape[0], max_length), dtype="i4")
+
+ position_ids = jnp.broadcast_to(
+ jnp.arange(input_ids.shape[-1] - 1)[None, :],
+ (input_ids.shape[0], input_ids.shape[-1] - 1),
+ )
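+ # run every token except the last one to pre-fill the cache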
+ outputs_cache = model(
+ input_ids[:, :-1],
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
+ outputs_cache_next = model(
+ input_ids[:, -1:],
+ attention_mask=attention_mask,
+ past_key_values=outputs_cache.past_key_values,
+ position_ids=position_ids,
+ )
+
+ outputs = model(input_ids)
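+ # the last-position hidden states of the incremental run should match a full forward pass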
+
+ diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
+ self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
+
+ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
+ max_length = 20
+ model = model_class_name(config)
+
+ input_ids, attention_mask = (
+ inputs_dict["input_ids"],
+ inputs_dict["attention_mask"],
+ )
+
+ attention_mask_cache = jnp.concatenate(
+ [
+ attention_mask,
+ jnp.zeros((attention_mask.shape[0], max_length - attention_mask.shape[1])),
+ ],
+ axis=-1,
+ )
+
+ past_key_values = model.init_cache(input_ids.shape[0], max_length)
+ position_ids = jnp.broadcast_to(
+ jnp.arange(input_ids.shape[-1] - 1)[None, :],
+ (input_ids.shape[0], input_ids.shape[-1] - 1),
+ )
+
+ outputs_cache = model(
+ input_ids[:, :-1],
+ attention_mask=attention_mask_cache,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+ position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
+ outputs_cache_next = model(
+ input_ids[:, -1:],
+ past_key_values=outputs_cache.past_key_values,
+ attention_mask=attention_mask_cache,
+ position_ids=position_ids,
+ )
+
+ outputs = model(input_ids, attention_mask=attention_mask)
+
+ diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
+ self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
+
+
+@require_flax
+class FlaxOPTModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
+ all_model_classes = (FlaxOPTModel, FlaxOPTForCausalLM) if is_flax_available() else ()
+ all_generative_model_classes = () if is_flax_available() else ()
+
+ def setUp(self):
+ self.model_tester = FlaxOPTModelTester(self)
+
+ def test_use_cache_forward(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+ for model_class in self.all_model_classes:
+ self.model_tester.check_use_cache_forward(model_class, config, inputs_dict)
+
+ def test_use_cache_forward_with_attn_mask(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+ for model_class in self.all_model_classes:
+ self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_class_name in self.all_model_classes:
+ model = model_class_name.from_pretrained("facebook/opt-125m")
+ input_ids = np.ones((1, 1)) * model.config.eos_token_id
+ outputs = model(input_ids)
+ self.assertIsNotNone(outputs)
+
+
+@require_sentencepiece
+@require_flax
+class FlaxOPTModelIntegrationTests(unittest.TestCase):
+ @slow
+ def test_inference_no_head(self):
+ model = FlaxOPTModel.from_pretrained("facebook/opt-350m")
+ input_ids = jnp.array([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+ output = model(input_ids=input_ids).last_hidden_state
+ expected_shape = (1, 11, 512)
+ self.assertEqual(output.shape, expected_shape)
+ expected_slice = jnp.array(
+ [[-0.2867, -1.9256, -0.3062], [-1.2711, -0.1337, -0.1897], [0.4109, 0.1187, -1.3142]]
+ )
+ self.assertTrue(jnp.allclose(output[:, :3, :3], expected_slice, atol=4e-2))
+
+
+@require_flax
+@slow
+class FlaxOPTEmbeddingsTest(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ self.path_model = "facebook/opt-350m"
+
+ def test_logits(self):
+ model = FlaxOPTForCausalLM.from_pretrained(self.path_model)
+ tokenizer = GPT2Tokenizer.from_pretrained(self.path_model)
+
+ prompts = [
+ "Today is a beautiful day and I want to",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+ # verify that prompt without BOS token is identical to Metaseq -> add_special_tokens=False
+ inputs = tokenizer(prompts, return_tensors="jax", padding=True, add_special_tokens=False)
+ logits = model(inputs.input_ids, attention_mask=inputs.attention_mask)[0].mean(axis=-1)
+ logits_meta = jnp.array(
+ [
+ [1.3851, -13.8923, -10.5229, -10.7533, -0.2309, -10.2384, -0.5365, -9.0947, -5.1670],
+ [-4.7073, -10.6276, -3.9415, -21.5242, -0.2822, -0.2822, -0.2822, -0.2822, -0.2822],
+ [0.6247, -3.4229, -8.9179, -1.4297, -14.1650, 1.4146, -9.0218, -0.2703, -0.2703],
+ [6.4783, -1.9913, -10.7926, -2.3336, 1.5092, -0.9974, -6.8213, 1.3477, 1.3477],
+ ]
+ )
+ self.assertTrue(jnp.allclose(logits, logits_meta, atol=4e-2))
+
+ model = jax.jit(model)
+ logits = model(inputs.input_ids, attention_mask=inputs.attention_mask)[0].mean(axis=-1)
+ self.assertTrue(jnp.allclose(logits, logits_meta, atol=4e-2))
+
+
+@require_flax
+@slow
+class FlaxOPTGenerationTest(unittest.TestCase):
+ @property
+ def prompts(self):
+ return [
+ "Today is a beautiful day and I want",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+
+ def test_generation_pre_attn_layer_norm(self):
+ model_id = "facebook/opt-125m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want to",
+ "In the city of New York, the city",
+ "Paris is the capital of France and the capital",
+ "Computers and mobile phones have taken over the",
+ ]
+
+ predicted_outputs = []
+
+ model = FlaxOPTForCausalLM.from_pretrained(model_id)
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="jax").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+ generated_ids = generated_ids[0]
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
+
+ def test_generation_post_attn_layer_norm(self):
+ model_id = "facebook/opt-350m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want to",
+ "In the city of San Francisco, the city",
+ "Paris is the capital of France and the capital",
+ "Computers and mobile phones have taken over the",
+ ]
+
+ predicted_outputs = []
+ model = FlaxOPTForCausalLM.from_pretrained(model_id)
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="jax").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+ generated_ids = generated_ids[0]
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
+
+ def test_jitted_batch_generation(self):
+ model_id = "facebook/opt-125m"
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want to thank",
+ "In the city of Rome Canaver Canaver Canaver Canaver",
+ ]
+ model = FlaxOPTForCausalLM.from_pretrained(model_id)
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ inputs = tokenizer(
+ [
+ "Today is a beautiful day and I want to",
+ "In the city of",
+ ],
+ return_tensors="jax",
+ padding=True,
+ )
+
+ jit_generate = jax.jit(model.generate)
+
+ output_sequences = jit_generate(inputs["input_ids"], attention_mask=inputs["attention_mask"]).sequences
+
+ output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
+
+ self.assertIsNotNone(output_string, EXPECTED_OUTPUTS)
+
+ # TODO fix in the following PR
+ # def test_batch_generation(self):
+ # model_id = "facebook/opt-350m"
+
+ # tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ # model = FlaxOPTForCausalLM.from_pretrained(model_id)
+
+ # tokenizer.padding_side = "left"
+
+ # # use different length sentences to test batching
+ # sentences = [
+ # "Hello, my dog is a little",
+ # "Today, I",
+ # ]
+
+ # inputs = tokenizer(sentences, return_tensors="jax", padding=True)
+ # input_ids = inputs["input_ids"]
+
+ # outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"], trace=False)
+
+ # inputs_non_padded = tokenizer(sentences[0], return_tensors="jax").input_ids
+ # output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+ # num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].sum()
+ # inputs_padded = tokenizer(sentences[1], return_tensors="jax").input_ids
+ # output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+ # batch_out_sentence = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
+ # non_padded_sentence = tokenizer.decode(output_non_padded[0][0], skip_special_tokens=True)
+ # padded_sentence = tokenizer.decode(output_padded[0][0], skip_special_tokens=True)
+
+ # expected_output_sentence = [
+ # "Hello, my dog is a little bit of a dork.\nI'm a little bit",
+ # "Today, I"
+ # # TODO fix this test in next PR
+ # # "Today, I was in the middle of a conversation with a friend about the",
+ # ]
+ # self.assertListEqual(expected_output_sentence, batch_out_sentence)
+ # # TODO outputs will be similar, fix in next PR
+ # self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py
new file mode 100644
index 000000000000..bdf3716b597d
--- /dev/null
+++ b/tests/models/opt/test_modeling_opt.py
@@ -0,0 +1,482 @@
+# coding=utf-8
+# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch OPT model. """
+
+
+import copy
+import tempfile
+import unittest
+
+import timeout_decorator # noqa
+
+from transformers import OPTConfig, is_torch_available
+from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
+
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import GPT2Tokenizer, OPTForCausalLM, OPTForSequenceClassification, OPTModel
+
+
+def prepare_opt_inputs_dict(
+ config,
+ input_ids,
+ decoder_input_ids=None,
+ attention_mask=None,
+ decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+):
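+ # default attention mask: attend to every non-pad token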
+ if attention_mask is None:
+ attention_mask = input_ids.ne(config.pad_token_id)
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "head_mask": head_mask,
+ }
+
+
+class OPTModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_labels=False,
+ vocab_size=99,
+ hidden_size=16,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=20,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ embed_dim=16,
+ num_labels=3,
+ word_embed_proj_dim=16,
+ type_sequence_label_size=2,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.embed_dim = embed_dim
+ self.num_labels = num_labels
+ self.type_sequence_label_size = type_sequence_label_size
+ self.word_embed_proj_dim = word_embed_proj_dim
+ self.is_encoder_decoder = False
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
+ 3,
+ )
+ input_ids[:, -1] = self.eos_token_id # Eos Token
+
+ decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ config = self.get_config()
+ inputs_dict = prepare_opt_inputs_dict(config, input_ids, decoder_input_ids)
+ return config, inputs_dict
+
+ def get_config(self):
+ return OPTConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ ffn_dim=self.intermediate_size,
+ dropout=self.hidden_dropout_prob,
+ attention_dropout=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.bos_token_id,
+ pad_token_id=self.pad_token_id,
+ embed_dim=self.embed_dim,
+ is_encoder_decoder=False,
+ word_embed_proj_dim=self.word_embed_proj_dim,
+ )
+
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.max_position_embeddings = 100
+ return config
+
+ def prepare_config_and_inputs_for_common(self):
+ config, inputs_dict = self.prepare_config_and_inputs()
+ return config, inputs_dict
+
+ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
+ model = OPTModel(config=config).to(torch_device).eval()
+
+ input_ids = inputs_dict["input_ids"]
+ attention_mask = inputs_dict["attention_mask"]
+ head_mask = inputs_dict["head_mask"]
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
+
+ output, past_key_values = outputs.to_tuple()
+
+ # create hypothetical next tokens and extend next_input_ids with them
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_attn_mask = ids_tensor((self.batch_size, 3), 2)
+
+ # append to next input_ids and attention_mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)
+
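+ # compare a full forward pass on the extended sequence with an incremental pass reusing past_key_values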
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
+ "last_hidden_state"
+ ]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+
+@require_torch
+class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ all_model_classes = (OPTModel, OPTForCausalLM, OPTForSequenceClassification) if is_torch_available() else ()
+ all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
+ is_encoder_decoder = False
+ fx_compatible = True
+ test_pruning = False
+ test_missing_keys = False
+
+ def setUp(self):
+ self.model_tester = OPTModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=OPTConfig)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_save_load_strict(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
+ self.assertEqual(info["missing_keys"], [])
+
+ def test_decoder_model_past_with_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
+ def test_inputs_embeds(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in (OPTModel,):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+ if not self.is_encoder_decoder:
+ input_ids = inputs["input_ids"]
+ del inputs["input_ids"]
+ else:
+ encoder_input_ids = inputs["input_ids"]
+ decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+ del inputs["input_ids"]
+ inputs.pop("decoder_input_ids", None)
+
+ wte = model.get_input_embeddings()
+ if not self.is_encoder_decoder:
+ inputs["inputs_embeds"] = wte(input_ids)
+ else:
+ inputs["inputs_embeds"] = wte(encoder_input_ids)
+ inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
+
+ with torch.no_grad():
+ model(**inputs)[0]
+
+ def test_generate_fp16(self):
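+ # smoke test: generation should run without error, also in fp16 when a GPU is available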
+ config, input_dict = self.model_tester.prepare_config_and_inputs()
+ input_ids = input_dict["input_ids"]
+ attention_mask = input_ids.ne(1).to(torch_device)
+ model = OPTForCausalLM(config).eval().to(torch_device)
+ if torch_device == "cuda":
+ model.half()
+ model.generate(input_ids, attention_mask=attention_mask)
+ model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
+
+ def test_opt_sequence_classification_model(self):
+ config, input_dict = self.model_tester.prepare_config_and_inputs()
+ config.num_labels = 3
+ input_ids = input_dict["input_ids"]
+ attention_mask = input_ids.ne(1).to(torch_device)
+ sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
+ model = OPTForSequenceClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+ self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+
+ def test_opt_sequence_classification_model_for_multi_label(self):
+ config, input_dict = self.model_tester.prepare_config_and_inputs()
+ config.num_labels = 3
+ config.problem_type = "multi_label_classification"
+ input_ids = input_dict["input_ids"]
+ attention_mask = input_ids.ne(1).to(torch_device)
+ sequence_labels = ids_tensor(
+ [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
+ ).to(torch.float)
+ model = OPTForSequenceClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+ self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+
+
+def assert_tensors_close(a, b, atol=1e-12, prefix=""):
+ """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
+ if a is None and b is None:
+ return True
+ try:
+ if torch.allclose(a, b, atol=atol):
+ return True
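+ # values differ beyond atol: fall through to the except branch below to build a descriptive message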
+ raise
+ except Exception:
+ pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item()
+ if a.numel() > 100:
+ msg = f"tensor values are {pct_different:.1%} percent different."
+ else:
+ msg = f"{a} != {b}"
+ if prefix:
+ msg = prefix + ": " + msg
+ raise AssertionError(msg)
+
+
+def _long_tensor(tok_lst):
+ return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)
+
+
+@require_torch
+class OPTModelIntegrationTests(unittest.TestCase):
+ @slow
+ def test_inference_no_head(self):
+ model = OPTModel.from_pretrained("facebook/opt-350m").to(torch_device)
+ input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+
+ with torch.no_grad():
+ output = model(input_ids=input_ids).last_hidden_state
+
+ expected_shape = torch.Size((1, 11, 512))
+ self.assertEqual(output.shape, expected_shape)
+ # expected value works for CPU, as well as GPU (with TF32 disabled)
+ expected_slice = torch.tensor(
+ [
+ [-0.28726277, -1.9241608, -0.3058734],
+ [-1.2737825, -0.13332152, -0.18766522],
+ [0.41159445, 0.1191957, -1.3107123],
+ ],
+ device=torch_device,
+ )
+ assert_tensors_close(output[0, :3, :3], expected_slice, atol=5e-5)
+
+
+@require_torch
+@slow
+class OPTEmbeddingsTest(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ self.path_model = "facebook/opt-350m"
+
+ def test_load_model(self):
+ try:
+ _ = OPTForCausalLM.from_pretrained(self.path_model)
+ except BaseException:
+ self.fail("Failed loading model")
+
+ def test_logits(self):
+ model = OPTForCausalLM.from_pretrained(self.path_model)
+ model = model.eval()
+ tokenizer = GPT2Tokenizer.from_pretrained(self.path_model)
+
+ prompts = [
+ "Today is a beautiful day and I want to",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+ # verify that prompts without a BOS token match Metaseq -> add_special_tokens=False
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True, add_special_tokens=False)
+ logits = model(inputs.input_ids, attention_mask=inputs.attention_mask)[0].mean(dim=-1)
+ # logits_meta = torch.load(self.path_logits_meta)
+ logits_meta = torch.Tensor(
+ [
+ [1.3851, -13.8923, -10.5229, -10.7533, -0.2309, -10.2384, -0.5365, -9.0947, -5.1670],
+ [-4.7073, -10.6276, -3.9415, -21.5242, -0.2822, -0.2822, -0.2822, -0.2822, -0.2822],
+ [0.6247, -3.4229, -8.9179, -1.4297, -14.1650, 1.4146, -9.0218, -0.2703, -0.2703],
+ [6.4783, -1.9913, -10.7926, -2.3336, 1.5092, -0.9974, -6.8213, 1.3477, 1.3477],
+ ]
+ )
+ assert torch.allclose(logits, logits_meta, atol=1e-4)
+
+
+@slow
+class OPTGenerationTest(unittest.TestCase):
+ @property
+ def prompts(self):
+ return [
+ "Today is a beautiful day and I want",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+
+ def test_generation_pre_attn_layer_norm(self):
+ model_id = "facebook/opt-125m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want to",
+ "In the city of New York, the city",
+ "Paris is the capital of France and the capital",
+ "Computers and mobile phones have taken over the",
+ ]
+
+ predicted_outputs = []
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = OPTForCausalLM.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
+
+ def test_batch_generation(self):
+ model_id = "facebook/opt-350m"
+
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = OPTForCausalLM.from_pretrained(model_id)
+ model.to(torch_device)
+
+ tokenizer.padding_side = "left"
+
+ # use different length sentences to test batching
+ sentences = [
+ "Hello, my dog is a little",
+ "Today, I",
+ ]
+
+ inputs = tokenizer(sentences, return_tensors="pt", padding=True)
+ input_ids = inputs["input_ids"].to(torch_device)
+
+ outputs = model.generate(
+ input_ids=input_ids,
+ attention_mask=inputs["attention_mask"].to(torch_device),
+ )
+
+ inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
+ output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+ num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+ inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
+ output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+ batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
+ padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+
+ expected_output_sentence = [
+ "Hello, my dog is a little bit of a dork.\nI'm a little bit",
+ "Today, I was in the middle of a conversation with a friend about the",
+ ]
+ self.assertListEqual(expected_output_sentence, batch_out_sentence)
+ self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
+
+ def test_generation_post_attn_layer_norm(self):
+ model_id = "facebook/opt-350m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want to",
+ "In the city of San Francisco, the city",
+ "Paris is the capital of France and the capital",
+ "Computers and mobile phones have taken over the",
+ ]
+
+ predicted_outputs = []
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = OPTForCausalLM.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
+
+ @require_torch_gpu
+ def test_batched_nan_fp16(self):
+ # a bug manifested in facebook/opt-1.3b and larger models when running batched generation,
+ # so instead of a tiny model we use the smallest model the problem was observed with, opt-1.3b.
+ # please refer to this github thread: https://github.com/huggingface/transformers/pull/17437 for more details
+ model_name = "facebook/opt-1.3b"
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name, use_fast=False, padding_side="left")
+
+ model = OPTForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda()
+ model = model.eval()
+
+ batch = tokenizer(["Who are you?", "Joe Biden is the president of"], padding=True, return_tensors="pt")
+
+ input_ids = batch["input_ids"].cuda()
+ attention_mask = batch["attention_mask"].cuda()
+
+ with torch.no_grad():
+ outputs = model(input_ids, attention_mask=attention_mask)
+ self.assertFalse(
+ torch.isnan(outputs.logits[0]).any().item()
+ ) # if the bug is present, the logits of the first (left-padded) example contain NaNs
diff --git a/tests/models/opt/test_modeling_tf_opt.py b/tests/models/opt/test_modeling_tf_opt.py
new file mode 100644
index 000000000000..61d6aad53fc1
--- /dev/null
+++ b/tests/models/opt/test_modeling_tf_opt.py
@@ -0,0 +1,414 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+from transformers import OPTConfig, is_tf_available
+from transformers.testing_utils import require_sentencepiece, require_tf, slow, tooslow
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+
+
+if is_tf_available():
+ import tensorflow as tf
+
+ from transformers import GPT2Tokenizer, TFOPTForCausalLM, TFOPTModel
+
+
+def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=None):
+ if attention_mask is None:
+ attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
+@require_tf
+class TFOPTModelTester:
+ config_cls = OPTConfig
+ config_updates = {}
+ hidden_act = "gelu"
+
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_labels=False,
+ vocab_size=99,
+ hidden_size=16,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=20,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ embed_dim=16,
+ word_embed_proj_dim=16,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.embed_dim = embed_dim
+ self.word_embed_proj_dim = word_embed_proj_dim
+ self.is_encoder_decoder = False
+
+ def prepare_config_and_inputs_for_common(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
+ eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
+ input_ids = tf.concat([input_ids, eos_tensor], axis=1)
+
+ config = self.config_cls(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ ffn_dim=self.intermediate_size,
+ dropout=self.hidden_dropout_prob,
+ attention_dropout=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.bos_token_id,
+ pad_token_id=self.pad_token_id,
+ embed_dim=self.embed_dim,
+ word_embed_proj_dim=self.word_embed_proj_dim,
+ is_encoder_decoder=False,
+ **self.config_updates,
+ )
+ inputs_dict = prepare_opt_inputs_dict(config, input_ids)
+ return config, inputs_dict
+
+ def check_decoder_model_past_large_inputs(self, config, inputs_dict):
+ model = TFOPTModel(config=config)
+ input_ids = inputs_dict["input_ids"]
+
+ input_ids = input_ids[:1, :]
+ attention_mask = inputs_dict["attention_mask"][:1, :]
+ self.batch_size = 1
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
+
+ output, past_key_values = outputs.to_tuple()
+
+ # create hypothetical next tokens and extend next_input_ids with them
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_attn_mask = tf.cast(ids_tensor((self.batch_size, 3), 2), tf.int8)
+
+ # append to next input_ids and attention_mask
+ next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
+ next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)
+
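+ # compare a full forward pass on the extended sequence with an incremental pass reusing past_key_values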
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
+ output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]
+
+ self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
+
+ # select random slice
+ random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
+ output_from_past_slice = output_from_past[:, :, random_slice_idx]
+
+ # test that outputs are equal for slice
+ tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
+
+
+@require_tf
+class TFOPTModelTest(TFModelTesterMixin, unittest.TestCase):
+ all_model_classes = (TFOPTModel, TFOPTForCausalLM) if is_tf_available() else ()
+ all_generative_model_classes = (TFOPTForCausalLM,) if is_tf_available() else ()
+ is_encoder_decoder = False
+ test_pruning = False
+ test_onnx = False
+ onnx_min_opset = 10
+
+ def setUp(self):
+ self.model_tester = TFOPTModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=OPTConfig)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_decoder_model_past_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
+
+ def test_model_common_attributes(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+
+ if model_class in self.all_generative_model_classes:
+ x = model.get_output_embeddings()
+ assert isinstance(x, tf.keras.layers.Layer)
+ else:
+ x = model.get_output_embeddings()
+ assert x is None
+
+ def test_resize_token_embeddings(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ def _get_word_embedding_weight(model, embedding_layer):
+ if hasattr(embedding_layer, "weight"):
+ return embedding_layer.weight
+ else:
+ # Build the word embedding weights if they do not exist yet,
+ # then retry fetching the attribute once the model has been built.
+ model(model.dummy_inputs)
+ if hasattr(embedding_layer, "weight"):
+ return embedding_layer.weight
+ else:
+ return None
+
+ for model_class in self.all_model_classes:
+ for size in [config.vocab_size - 10, config.vocab_size + 10]:
+ # build the embeddings
+ model = model_class(config=config)
+ old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
+ old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
+
+ # reshape the embeddings
+ model.resize_token_embeddings(size)
+ new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
+ new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
+
+ # check that the resized embeddings size matches the desired size.
+ assert_size = size if size is not None else config.vocab_size
+
+ self.assertEqual(new_input_embeddings.shape[0], assert_size)
+
+ # check that weights remain the same after resizing
+ models_equal = True
+ for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
+ if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+ models_equal = False
+ self.assertTrue(models_equal)
+
+ if old_output_embeddings is not None and new_output_embeddings is not None:
+ self.assertEqual(new_output_embeddings.shape[0], assert_size)
+
+ models_equal = True
+ for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
+ if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+ models_equal = False
+ self.assertTrue(models_equal)
+
+ @tooslow
+ def test_saved_model_creation(self):
+ pass
+
+
+def _long_tensor(tok_lst):
+ return tf.constant(tok_lst, dtype=tf.int32)
+
+
+@require_tf
+class TFOPTHeadTests(unittest.TestCase):
+ vocab_size = 99
+
+ def _get_config_and_data(self):
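+ # build a small batch whose last column is the EOS token, together with a matching tiny config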
+ eos_column_vector = tf.ones((4, 1), dtype=tf.int32) * 2
+ input_ids = tf.concat([ids_tensor((4, 6), self.vocab_size - 3) + 3, eos_column_vector], axis=1)
+ batch_size = input_ids.shape[0]
+ config = OPTConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=24,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ ffn_dim=32,
+ max_position_embeddings=48,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ )
+ return config, input_ids, batch_size
+
+
+@require_sentencepiece
+@require_tf
+class OPTModelIntegrationTests(unittest.TestCase):
+ @slow
+ def test_inference_no_head(self):
+ model = TFOPTModel.from_pretrained("facebook/opt-350m")
+ input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+ attention_mask = tf.not_equal(input_ids, model.config.pad_token_id)
+ with tf.GradientTape():
+ output = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
+ expected_shape = (1, 11, 512)
+ self.assertEqual(output.shape, expected_shape)
+ expected_slice = tf.constant(
+ [[-0.2873, -1.9218, -0.3033], [-1.2710, -0.1338, -0.1902], [0.4095, 0.1214, -1.3121]]
+ )
+ self.assertTrue(np.allclose(output[:, :3, :3], expected_slice, atol=4e-3))
+
+ xla_generate = tf.function(model, jit_compile=True)
+ output = xla_generate(input_ids, attention_mask)[0]
+ self.assertTrue(np.allclose(output[:, :3, :3], expected_slice, atol=4e-2))
+
+
+@require_tf
+@slow
+class TFOPTEmbeddingsTest(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ self.path_model = "facebook/opt-350m"
+
+ def test_logits(self):
+ model = TFOPTForCausalLM.from_pretrained(self.path_model)
+ tokenizer = GPT2Tokenizer.from_pretrained(self.path_model)
+
+ prompts = [
+ "Today is a beautiful day and I want to",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+ # verify that prompts without a BOS token match Metaseq -> add_special_tokens=False
+ inputs = tokenizer(prompts, return_tensors="tf", padding=True, add_special_tokens=False)
+ logits = tf.math.reduce_mean(model(inputs.input_ids, attention_mask=inputs.attention_mask)[0], axis=-1)
+ logits_meta = tf.constant(
+ [
+ [1.3851, -13.8923, -10.5229, -10.7533, -0.2309, -10.2384, -0.5365, -9.0947, -5.1670],
+ [-4.7073, -10.6276, -3.9415, -21.5242, -0.2822, -0.2822, -0.2822, -0.2822, -0.2822],
+ [0.6247, -3.4229, -8.9179, -1.4297, -14.1650, 1.4146, -9.0218, -0.2703, -0.2703],
+ [6.4783, -1.9913, -10.7926, -2.3336, 1.5092, -0.9974, -6.8213, 1.3477, 1.3477],
+ ]
+ )
+ self.assertTrue(np.allclose(logits, logits_meta, atol=1e-4))
+
+ xla_generate = tf.function(model, jit_compile=True)
+ logits = tf.math.reduce_mean(xla_generate(inputs.input_ids, attention_mask=inputs.attention_mask)[0], axis=-1)
+ self.assertTrue(np.allclose(logits, logits_meta, atol=1e-4))
+
+
+@slow
+class TFOPTGenerationTest(unittest.TestCase):
+ @property
+ def prompts(self):
+ return [
+ "Today is a beautiful day and I want",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+
+ def test_generation_pre_attn_layer_norm(self):
+ model_id = "facebook/opt-125m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want to",
+ "In the city of New York, the city",
+ "Paris is the capital of France and the capital",
+ "Computers and mobile phones have taken over the",
+ ]
+
+ predicted_outputs = []
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = TFOPTForCausalLM.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="tf").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
+
+ def test_batch_generation(self):
+ model_id = "facebook/opt-350m"
+
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = TFOPTForCausalLM.from_pretrained(model_id)
+
+ tokenizer.padding_side = "left"
+
+ # use different length sentences to test batching
+ sentences = [
+ "Hello, my dog is a little",
+ "Today, I",
+ ]
+
+ inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+ input_ids = inputs["input_ids"]
+
+ outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"])
+
+ inputs_non_padded = tokenizer(sentences[0], return_tensors="tf").input_ids
+ output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+ num_paddings = inputs_non_padded.shape[-1] - tf.math.reduce_sum(
+ tf.cast(inputs["attention_mask"][-1], tf.int64)
+ )
+ inputs_padded = tokenizer(sentences[1], return_tensors="tf").input_ids
+ output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+ batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
+ padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+
+ expected_output_sentence = [
+ "Hello, my dog is a little bit of a dork.\nI'm a little bit",
+ "Today, I was in the middle of a conversation with a friend about the",
+ ]
+ self.assertListEqual(expected_output_sentence, batch_out_sentence)
+ self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
+
+ def test_generation_post_attn_layer_norm(self):
+ model_id = "facebook/opt-350m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want to",
+ "In the city of San Francisco, the city",
+ "Paris is the capital of France and the capital",
+ "Computers and mobile phones have taken over the",
+ ]
+
+ predicted_outputs = []
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = TFOPTForCausalLM.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="tf").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
diff --git a/tests/models/owlvit/__init__.py b/tests/models/owlvit/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/owlvit/test_feature_extraction_owlvit.py b/tests/models/owlvit/test_feature_extraction_owlvit.py
new file mode 100644
index 000000000000..c9198280d792
--- /dev/null
+++ b/tests/models/owlvit/test_feature_extraction_owlvit.py
@@ -0,0 +1,201 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import OwlViTFeatureExtractor
+
+
+class OwlViTFeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ image_size=18,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=20,
+ do_center_crop=True,
+ crop_size=18,
+ do_normalize=True,
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
+ image_std=[0.26862954, 0.26130258, 0.27577711],
+ do_convert_rgb=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.do_convert_rgb = do_convert_rgb
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "do_resize": self.do_resize,
+ "size": self.size,
+ "do_center_crop": self.do_center_crop,
+ "crop_size": self.crop_size,
+ "do_normalize": self.do_normalize,
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_convert_rgb": self.do_convert_rgb,
+ }
+
+
+@require_torch
+@require_vision
+class OwlViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = OwlViTFeatureExtractor if is_vision_available() else None
+
+ def setUp(self):
+ self.feature_extract_tester = OwlViTFeatureExtractionTester(self)
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "size"))
+ self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "image_std"))
+ self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
+
+ def test_call_pil(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
+
+ def test_call_numpy(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random numpy tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, np.ndarray)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
+
+ def test_call_pytorch(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PyTorch tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, torch.Tensor)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
diff --git a/tests/models/owlvit/test_modeling_owlvit.py b/tests/models/owlvit/test_modeling_owlvit.py
new file mode 100644
index 000000000000..edddc53beeab
--- /dev/null
+++ b/tests/models/owlvit/test_modeling_owlvit.py
@@ -0,0 +1,793 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch OwlViT model. """
+
+
+import inspect
+import os
+import tempfile
+import unittest
+from typing import Dict, List, Tuple
+
+import numpy as np
+
+import requests
+from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ _config_zero_init,
+ floats_tensor,
+ ids_tensor,
+ random_attention_mask,
+)
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import OwlViTForObjectDetection, OwlViTModel, OwlViTTextModel, OwlViTVisionModel
+ from transformers.models.owlvit.modeling_owlvit import OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import OwlViTProcessor
+
+
+class OwlViTVisionModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=12,
+ image_size=32,
+ patch_size=2,
+ num_channels=3,
+ is_training=True,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ dropout=0.1,
+ attention_dropout=0.1,
+ initializer_range=0.02,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.is_training = is_training
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.scope = scope
+
+ # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
+ num_patches = (image_size // patch_size) ** 2
+ self.seq_length = num_patches + 1
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+ config = self.get_config()
+
+ return config, pixel_values
+
+ def get_config(self):
+ return OwlViTVisionConfig(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ dropout=self.dropout,
+ attention_dropout=self.attention_dropout,
+ initializer_range=self.initializer_range,
+ )
+
+ def create_and_check_model(self, config, pixel_values):
+ model = OwlViTVisionModel(config=config).to(torch_device)
+ model.eval()
+
+ pixel_values = pixel_values.to(torch.float32)
+
+ with torch.no_grad():
+ result = model(pixel_values)
+ # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
+ num_patches = (self.image_size // self.patch_size) ** 2
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class OwlViTVisionModelTest(ModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as OWLVIT does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (OwlViTVisionModel,) if is_torch_available() else ()
+ fx_compatible = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = OwlViTVisionModelTester(self)
+ self.config_tester = ConfigTester(
+ self, config_class=OwlViTVisionConfig, has_text_modality=False, hidden_size=37
+ )
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(reason="OWLVIT does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ def test_model_common_attributes(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ @unittest.skip(reason="OWL-ViT does not support training yet")
+ def test_training(self):
+ pass
+
+ @unittest.skip(reason="OWL-ViT does not support training yet")
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ @unittest.skip(reason="OwlViTVisionModel has no base class and is not available in MODEL_MAPPING")
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ @unittest.skip(reason="OwlViTVisionModel has no base class and is not available in MODEL_MAPPING")
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = OwlViTVisionModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class OwlViTTextModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=12,
+ num_queries=4,
+ seq_length=16,
+ is_training=True,
+ use_input_mask=True,
+ use_labels=True,
+ vocab_size=99,
+ hidden_size=64,
+ num_hidden_layers=12,
+ num_attention_heads=4,
+ intermediate_size=37,
+ dropout=0.1,
+ attention_dropout=0.1,
+ max_position_embeddings=16,
+ initializer_range=0.02,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_queries = num_queries
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+ self.scope = scope
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size * self.num_queries, self.seq_length], self.vocab_size)
+ input_mask = None
+
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size * self.num_queries, self.seq_length])
+
+ if input_mask is not None:
+ num_text, seq_length = input_mask.shape
+
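+ # make each attention mask a contiguous block of ones followed by zeros (valid tokens first, then padding)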
+ rnd_start_indices = np.random.randint(1, seq_length - 1, size=(num_text,))
+ for idx, start_index in enumerate(rnd_start_indices):
+ input_mask[idx, :start_index] = 1
+ input_mask[idx, start_index:] = 0
+
+ config = self.get_config()
+
+ return config, input_ids, input_mask
+
+ def get_config(self):
+ return OwlViTTextConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ dropout=self.dropout,
+ attention_dropout=self.attention_dropout,
+ max_position_embeddings=self.max_position_embeddings,
+ initializer_range=self.initializer_range,
+ )
+
+ def create_and_check_model(self, config, input_ids, input_mask):
+ model = OwlViTTextModel(config=config).to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(input_ids=input_ids, attention_mask=input_mask)
+
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size * self.num_queries, self.seq_length, self.hidden_size)
+ )
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size * self.num_queries, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, input_ids, input_mask = config_and_inputs
+ inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class OwlViTTextModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (OwlViTTextModel,) if is_torch_available() else ()
+ fx_compatible = False
+ test_pruning = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = OwlViTTextModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=OwlViTTextConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ @unittest.skip(reason="OWL-ViT does not support training yet")
+ def test_training(self):
+ pass
+
+ @unittest.skip(reason="OWL-ViT does not support training yet")
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ @unittest.skip(reason="OWLVIT does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="OwlViTTextModel has no base class and is not available in MODEL_MAPPING")
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ @unittest.skip(reason="OwlViTTextModel has no base class and is not available in MODEL_MAPPING")
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = OwlViTTextModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class OwlViTModelTester:
+ def __init__(self, parent, is_training=True):
+ self.parent = parent
+ self.text_model_tester = OwlViTTextModelTester(parent)
+ self.vision_model_tester = OwlViTVisionModelTester(parent)
+ self.is_training = is_training
+ self.text_config = self.text_model_tester.get_config().to_dict()
+ self.vision_config = self.vision_model_tester.get_config().to_dict()
+
+ def prepare_config_and_inputs(self):
+ text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
+ vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
+ config = self.get_config()
+ return config, input_ids, attention_mask, pixel_values
+
+ def get_config(self):
+ return OwlViTConfig.from_text_vision_configs(self.text_config, self.vision_config, projection_dim=64)
+
+ def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
+ model = OwlViTModel(config).to(torch_device).eval()
+
+ with torch.no_grad():
+ result = model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ )
+
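+ # every image is scored against every text query (batch_size * num_queries text inputs)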
+ image_logits_size = (
+ self.vision_model_tester.batch_size,
+ self.text_model_tester.batch_size * self.text_model_tester.num_queries,
+ )
+ text_logits_size = (
+ self.text_model_tester.batch_size * self.text_model_tester.num_queries,
+ self.vision_model_tester.batch_size,
+ )
+ self.parent.assertEqual(result.logits_per_image.shape, image_logits_size)
+ self.parent.assertEqual(result.logits_per_text.shape, text_logits_size)
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, input_ids, attention_mask, pixel_values = config_and_inputs
+ inputs_dict = {
+ "pixel_values": pixel_values,
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "return_loss": False,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class OwlViTModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (OwlViTModel,) if is_torch_available() else ()
+ fx_compatible = False
+ test_head_masking = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_attention_outputs = False
+
+ def setUp(self):
+ self.model_tester = OwlViTModelTester(self)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ @unittest.skip(reason="Hidden_states is tested in individual model tests")
+ def test_hidden_states_output(self):
+ pass
+
+ @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="Retain_grad is tested in individual model tests")
+ def test_retain_grad_hidden_states_attentions(self):
+ pass
+
+ @unittest.skip(reason="OwlViTModel does not have input/output embeddings")
+ def test_model_common_attributes(self):
+ pass
+
+ # override as the `logit_scale` parameter initialization is different for OWLVIT
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ # check if `logit_scale` is initialized as per the original implementation
+ if name == "logit_scale":
+ self.assertAlmostEqual(
+ param.data.item(),
+ np.log(1 / 0.07),
+ delta=1e-3,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ def _create_and_check_torchscript(self, config, inputs_dict):
+ if not self.test_torchscript:
+ return
+
+ configs_no_init = _config_zero_init(config) # To be sure we have no Nan
+ configs_no_init.torchscript = True
+ configs_no_init.return_dict = False
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init).to(torch_device)
+ model.eval()
+
+ try:
+ input_ids = inputs_dict["input_ids"]
+ pixel_values = inputs_dict["pixel_values"] # OWLVIT needs pixel_values
+ traced_model = torch.jit.trace(model, (input_ids, pixel_values))
+ except RuntimeError:
+ self.fail("Couldn't trace module.")
+
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
+
+ try:
+ torch.jit.save(traced_model, pt_file_name)
+ except Exception:
+ self.fail("Couldn't save module.")
+
+ try:
+ loaded_model = torch.jit.load(pt_file_name)
+ except Exception:
+ self.fail("Couldn't load module.")
+
+ loaded_model = loaded_model.to(torch_device)
+ loaded_model.eval()
+
+ model_state_dict = model.state_dict()
+ loaded_model_state_dict = loaded_model.state_dict()
+
+ self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
+
+ models_equal = True
+ for layer_name, p1 in model_state_dict.items():
+ p2 = loaded_model_state_dict[layer_name]
+ if p1.data.ne(p2.data).sum() > 0:
+ models_equal = False
+
+ self.assertTrue(models_equal)
+
+ def test_load_vision_text_config(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ # Save OwlViTConfig and check if we can load OwlViTVisionConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ vision_config = OwlViTVisionConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
+
+ # Save OwlViTConfig and check if we can load OwlViTTextConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ text_config = OwlViTTextConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = OwlViTModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class OwlViTForObjectDetectionTester:
+ def __init__(self, parent, is_training=True):
+ self.parent = parent
+ self.text_model_tester = OwlViTTextModelTester(parent)
+ self.vision_model_tester = OwlViTVisionModelTester(parent)
+ self.is_training = is_training
+ self.text_config = self.text_model_tester.get_config().to_dict()
+ self.vision_config = self.vision_model_tester.get_config().to_dict()
+
+ def prepare_config_and_inputs(self):
+ text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
+ vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
+ config = self.get_config()
+ return config, pixel_values, input_ids, attention_mask
+
+ def get_config(self):
+ return OwlViTConfig.from_text_vision_configs(self.text_config, self.vision_config, projection_dim=64)
+
+ def create_and_check_model(self, config, pixel_values, input_ids, attention_mask):
+ model = OwlViTForObjectDetection(config).to(torch_device).eval()
+ with torch.no_grad():
+ result = model(
+ pixel_values=pixel_values,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ return_dict=True,
+ )
+
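+ # OWL-ViT predicts one box and one class embedding per image patch, i.e. (image_size // patch_size) ** 2 queries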
+ pred_boxes_size = (
+ self.vision_model_tester.batch_size,
+ (self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2,
+ 4,
+ )
+ pred_logits_size = (
+ self.vision_model_tester.batch_size,
+ (self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2,
+ 4,
+ )
+ pred_class_embeds_size = (
+ self.vision_model_tester.batch_size,
+ (self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2,
+ self.text_model_tester.hidden_size,
+ )
+ self.parent.assertEqual(result.pred_boxes.shape, pred_boxes_size)
+ self.parent.assertEqual(result.logits.shape, pred_logits_size)
+ self.parent.assertEqual(result.class_embeds.shape, pred_class_embeds_size)
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, input_ids, attention_mask = config_and_inputs
+ inputs_dict = {
+ "pixel_values": pixel_values,
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (OwlViTForObjectDetection,) if is_torch_available() else ()
+ fx_compatible = False
+ test_head_masking = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_attention_outputs = False
+
+ def setUp(self):
+ self.model_tester = OwlViTForObjectDetectionTester(self)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ @unittest.skip(reason="Hidden_states is tested in individual model tests")
+ def test_hidden_states_output(self):
+ pass
+
+ @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="Retain_grad is tested in individual model tests")
+ def test_retain_grad_hidden_states_attentions(self):
+ pass
+
+ @unittest.skip(reason="OwlViTModel does not have input/output embeddings")
+ def test_model_common_attributes(self):
+ pass
+
+ @unittest.skip(reason="Test_initialization is tested in individual model tests")
+ def test_initialization(self):
+ pass
+
+ @unittest.skip(reason="Test_forward_signature is tested in individual model tests")
+ def test_forward_signature(self):
+ pass
+
+ @unittest.skip(reason="Test_save_load_fast_init_from_base is tested in individual model tests")
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ @unittest.skip(reason="OWL-ViT does not support training yet")
+ def test_training(self):
+ pass
+
+ @unittest.skip(reason="OWL-ViT does not support training yet")
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ def _create_and_check_torchscript(self, config, inputs_dict):
+ if not self.test_torchscript:
+ return
+
+ configs_no_init = _config_zero_init(config) # To be sure we have no Nan
+ configs_no_init.torchscript = True
+ configs_no_init.return_dict = False
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init).to(torch_device)
+ model.eval()
+
+ try:
+ input_ids = inputs_dict["input_ids"]
+ pixel_values = inputs_dict["pixel_values"] # OWLVIT needs pixel_values
+ traced_model = torch.jit.trace(model, (input_ids, pixel_values))
+ except RuntimeError:
+ self.fail("Couldn't trace module.")
+
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
+
+ try:
+ torch.jit.save(traced_model, pt_file_name)
+ except Exception:
+ self.fail("Couldn't save module.")
+
+ try:
+ loaded_model = torch.jit.load(pt_file_name)
+ except Exception:
+ self.fail("Couldn't load module.")
+
+ loaded_model = loaded_model.to(torch_device)
+ loaded_model.eval()
+
+ model_state_dict = model.state_dict()
+ loaded_model_state_dict = loaded_model.state_dict()
+
+ self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
+
+ models_equal = True
+ for layer_name, p1 in model_state_dict.items():
+ p2 = loaded_model_state_dict[layer_name]
+ if p1.data.ne(p2.data).sum() > 0:
+ models_equal = False
+
+ self.assertTrue(models_equal)
+
+ def test_model_outputs_equivalence(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ def set_nan_tensor_to_zero(t):
+ t[t != t] = 0
+ return t
+
+ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+ with torch.no_grad():
+ tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
+ dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
+
+ def recursive_check(tuple_object, dict_object):
+ if isinstance(tuple_object, (List, Tuple)):
+ for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
+ recursive_check(tuple_iterable_value, dict_iterable_value)
+ elif isinstance(tuple_object, Dict):
+ for tuple_iterable_value, dict_iterable_value in zip(
+ tuple_object.values(), dict_object.values()
+ ):
+ recursive_check(tuple_iterable_value, dict_iterable_value)
+ elif tuple_object is None:
+ return
+ else:
+ self.assertTrue(
+ torch.allclose(
+ set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
+ ),
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+ f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
+ f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
+ ),
+ )
+
+ recursive_check(tuple_output, dict_output)
+
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+ model.eval()
+
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+ check_equivalence(model, tuple_inputs, dict_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = OwlViTForObjectDetection.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ im = Image.open(requests.get(url, stream=True).raw)
+ return im
+
+
+@require_vision
+@require_torch
+class OwlViTModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference(self):
+ model_name = "google/owlvit-base-patch32"
+ model = OwlViTModel.from_pretrained(model_name).to(torch_device)
+ processor = OwlViTProcessor.from_pretrained(model_name)
+
+ image = prepare_img()
+ inputs = processor(
+ text=[["a photo of a cat", "a photo of a dog"]],
+ images=image,
+ max_length=16,
+ padding="max_length",
+ return_tensors="pt",
+ ).to(torch_device)
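+ # the processor flattens the nested list of text queries, so input_ids here has shape (num_queries, max_length) = (2, 16)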
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ self.assertEqual(
+ outputs.logits_per_image.shape,
+ torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
+ )
+ self.assertEqual(
+ outputs.logits_per_text.shape,
+ torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
+ )
+ expected_logits = torch.tensor([[4.4420, 0.6181]], device=torch_device)
+
+ self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
+
+ @slow
+ def test_inference_object_detection(self):
+ model_name = "google/owlvit-base-patch32"
+ model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device)
+
+ processor = OwlViTProcessor.from_pretrained(model_name)
+
+ image = prepare_img()
+ inputs = processor(
+ text=[["a photo of a cat", "a photo of a dog"]],
+ images=image,
+ max_length=16,
+ padding="max_length",
+ return_tensors="pt",
+ ).to(torch_device)
+
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
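+ # one query per image patch: (768 // 32) ** 2 = 576 for the owlvit-base-patch32 checkpoint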
+ self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
+ expected_slice_boxes = torch.tensor(
+ [[0.0948, 0.0471, 0.1915], [0.3194, 0.0583, 0.6498], [0.1441, 0.0452, 0.2197]]
+ ).to(torch_device)
+ self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
diff --git a/tests/models/owlvit/test_processor_owlvit.py b/tests/models/owlvit/test_processor_owlvit.py
new file mode 100644
index 000000000000..e37f45b15c8b
--- /dev/null
+++ b/tests/models/owlvit/test_processor_owlvit.py
@@ -0,0 +1,241 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+import pytest
+
+from transformers import CLIPTokenizer, CLIPTokenizerFast
+from transformers.models.clip.tokenization_clip import VOCAB_FILES_NAMES
+from transformers.testing_utils import require_vision
+from transformers.utils import FEATURE_EXTRACTOR_NAME, is_vision_available
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import OwlViTFeatureExtractor, OwlViTProcessor
+
+
+@require_vision
+class OwlViTProcessorTest(unittest.TestCase):
+ def setUp(self):
+ self.tmpdirname = tempfile.mkdtemp()
+
+ # fmt: off
+ vocab = ["", "l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "lo", "l</w>", "w</w>", "r</w>", "t</w>", "low</w>", "er</w>", "lowest</w>", "newer</w>", "wider", "<unk>", "<|startoftext|>", "<|endoftext|>"]
+ # fmt: on
+ vocab_tokens = dict(zip(vocab, range(len(vocab))))
+ merges = ["#version: 0.2", "l o", "lo w</w>", "e r</w>", ""]
+ self.special_tokens_map = {"unk_token": "<unk>"}
+
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+ self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
+ with open(self.vocab_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(vocab_tokens) + "\n")
+ with open(self.merges_file, "w", encoding="utf-8") as fp:
+ fp.write("\n".join(merges))
+
+ feature_extractor_map = {
+ "do_resize": True,
+ "size": 20,
+ "do_center_crop": True,
+ "crop_size": 18,
+ "do_normalize": True,
+ "image_mean": [0.48145466, 0.4578275, 0.40821073],
+ "image_std": [0.26862954, 0.26130258, 0.27577711],
+ }
+ self.feature_extractor_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
+ with open(self.feature_extractor_file, "w", encoding="utf-8") as fp:
+ json.dump(feature_extractor_map, fp)
+
+ def get_tokenizer(self, **kwargs):
+ return CLIPTokenizer.from_pretrained(self.tmpdirname, pad_token="!", **kwargs)
+
+ def get_rust_tokenizer(self, **kwargs):
+ return CLIPTokenizerFast.from_pretrained(self.tmpdirname, pad_token="!", **kwargs)
+
+ def get_feature_extractor(self, **kwargs):
+ return OwlViTFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdirname)
+
+ def prepare_image_inputs(self):
+ """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
+ or a list of PyTorch tensors if one specifies torchify=True.
+ """
+
+ image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
+
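+ # np.moveaxis turns each (num_channels, height, width) array into (height, width, num_channels) before building a PIL image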
+ image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
+
+ return image_inputs
+
+ def test_save_load_pretrained_default(self):
+ tokenizer_slow = self.get_tokenizer()
+ tokenizer_fast = self.get_rust_tokenizer()
+ feature_extractor = self.get_feature_extractor()
+
+ processor_slow = OwlViTProcessor(tokenizer=tokenizer_slow, feature_extractor=feature_extractor)
+ processor_slow.save_pretrained(self.tmpdirname)
+ processor_slow = OwlViTProcessor.from_pretrained(self.tmpdirname, use_fast=False)
+
+ processor_fast = OwlViTProcessor(tokenizer=tokenizer_fast, feature_extractor=feature_extractor)
+ processor_fast.save_pretrained(self.tmpdirname)
+ processor_fast = OwlViTProcessor.from_pretrained(self.tmpdirname)
+
+ self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab())
+ self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab())
+ self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab())
+ self.assertIsInstance(processor_slow.tokenizer, CLIPTokenizer)
+ self.assertIsInstance(processor_fast.tokenizer, CLIPTokenizerFast)
+
+ self.assertEqual(processor_slow.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertEqual(processor_fast.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertIsInstance(processor_slow.feature_extractor, OwlViTFeatureExtractor)
+ self.assertIsInstance(processor_fast.feature_extractor, OwlViTFeatureExtractor)
+
+ def test_save_load_pretrained_additional_features(self):
+ processor = OwlViTProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
+ processor.save_pretrained(self.tmpdirname)
+
+ tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
+ feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False)
+
+ processor = OwlViTProcessor.from_pretrained(
+ self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False
+ )
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+ self.assertIsInstance(processor.tokenizer, CLIPTokenizerFast)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, OwlViTFeatureExtractor)
+
+ def test_feature_extractor(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ image_input = self.prepare_image_inputs()
+
+ input_feat_extract = feature_extractor(image_input, return_tensors="np")
+ input_processor = processor(images=image_input, return_tensors="np")
+
+ for key in input_feat_extract.keys():
+ self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+ def test_tokenizer(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ input_str = "lower newer"
+
+ encoded_processor = processor(text=input_str, return_tensors="np")
+
+ encoded_tok = tokenizer(input_str, return_tensors="np")
+
+ for key in encoded_tok.keys():
+ self.assertListEqual(encoded_tok[key][0].tolist(), encoded_processor[key][0].tolist())
+
+ def test_processor(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ input_str = "lower newer"
+ image_input = self.prepare_image_inputs()
+
+ inputs = processor(text=input_str, images=image_input)
+
+ self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask", "pixel_values"])
+
+ # test if it raises when no input is passed
+ with pytest.raises(ValueError):
+ processor()
+
+ def test_processor_with_text_list(self):
+ model_name = "google/owlvit-base-patch32"
+ processor = OwlViTProcessor.from_pretrained(model_name)
+
+ input_text = ["cat", "nasa badge"]
+ inputs = processor(text=input_text)
+
+ seq_length = 16
+ self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask"])
+ self.assertEqual(inputs["input_ids"].shape, (2, seq_length))
+
+ # test if it raises when no input is passed
+ with pytest.raises(ValueError):
+ processor()
+
+ def test_processor_with_nested_text_list(self):
+ model_name = "google/owlvit-base-patch32"
+ processor = OwlViTProcessor.from_pretrained(model_name)
+
+ input_texts = [["cat", "nasa badge"], ["person"]]
+ inputs = processor(text=input_texts)
+
+ seq_length = 16
+ batch_size = len(input_texts)
+ num_max_text_queries = max([len(texts) for texts in input_texts])
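+ # shorter query lists are padded by the processor so every example contributes num_max_text_queries rows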
+
+ self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask"])
+ self.assertEqual(inputs["input_ids"].shape, (batch_size * num_max_text_queries, seq_length))
+
+ # test if it raises when no input is passed
+ with pytest.raises(ValueError):
+ processor()
+
+ def test_processor_case(self):
+ model_name = "google/owlvit-base-patch32"
+ processor = OwlViTProcessor.from_pretrained(model_name)
+
+ input_texts = ["cat", "nasa badge"]
+ inputs = processor(text=input_texts)
+
+ seq_length = 16
+ input_ids = inputs["input_ids"]
+ predicted_ids = [
+ [49406, 2368, 49407, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [49406, 6841, 11301, 49407, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ ]
+
+ self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask"])
+ self.assertEqual(inputs["input_ids"].shape, (2, seq_length))
+ self.assertListEqual(list(input_ids[0]), predicted_ids[0])
+ self.assertListEqual(list(input_ids[1]), predicted_ids[1])
+
+ def test_tokenizer_decode(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
+
+ decoded_processor = processor.batch_decode(predicted_ids)
+ decoded_tok = tokenizer.batch_decode(predicted_ids)
+
+ self.assertListEqual(decoded_tok, decoded_processor)
diff --git a/tests/models/pegasus/test_modeling_pegasus.py b/tests/models/pegasus/test_modeling_pegasus.py
index a05e34e57c07..81ed90b8a96d 100644
--- a/tests/models/pegasus/test_modeling_pegasus.py
+++ b/tests/models/pegasus/test_modeling_pegasus.py
@@ -104,6 +104,12 @@ def __init__(
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
+ # forcing a certain token to be generated sets all other tokens to -inf;
+ # if the token to be forced is already at -inf, this can lead to `nan` values
+ # and thus break generation
+ self.forced_bos_token_id = None
+ self.forced_eos_token_id = None
+
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
@@ -151,6 +157,8 @@ def get_config(self):
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
+ forced_bos_token_id=self.forced_bos_token_id,
+ forced_eos_token_id=self.forced_eos_token_id,
)
def prepare_config_and_inputs_for_common(self):
@@ -229,6 +237,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_resize_position_embeddings = True
test_pruning = False
test_missing_keys = False
diff --git a/tests/models/pegasus/test_modeling_tf_pegasus.py b/tests/models/pegasus/test_modeling_tf_pegasus.py
index 594323a7dc45..c26b25fc55e0 100644
--- a/tests/models/pegasus/test_modeling_tf_pegasus.py
+++ b/tests/models/pegasus/test_modeling_tf_pegasus.py
@@ -17,7 +17,7 @@
import unittest
from transformers import AutoTokenizer, PegasusConfig, is_tf_available
-from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
+from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -244,8 +244,8 @@ def test_model_common_attributes(self):
name = model.get_bias()
assert name is None
+ @tooslow
def test_saved_model_creation(self):
- # This test is too long (>30sec) and makes fail the CI
pass
def test_resize_token_embeddings(self):
@@ -339,7 +339,8 @@ class TFPegasusIntegrationTests(unittest.TestCase):
""" The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """,
]
expected_text = [
- "California's largest electricity provider has cut power to hundreds of thousands of customers in an effort to reduce the risk of wildfires.",
+ "California's largest electricity provider has cut power to hundreds of thousands of customers in an effort to"
+ " reduce the risk of wildfires.",
'N-Dubz have revealed they\'re "grateful" to have been nominated for four Mobo Awards.',
] # differs slightly from pytorch, likely due to numerical differences in linear layers
model_name = "google/pegasus-xsum"
diff --git a/tests/models/pegasus/test_tokenization_pegasus.py b/tests/models/pegasus/test_tokenization_pegasus.py
index 3f83e84178e7..de2886a5e120 100644
--- a/tests/models/pegasus/test_tokenization_pegasus.py
+++ b/tests/models/pegasus/test_tokenization_pegasus.py
@@ -72,7 +72,10 @@ def test_vocab_size(self):
def test_mask_tokens_rust_pegasus(self):
rust_tokenizer = self.rust_tokenizer_class.from_pretrained(self.tmpdirname)
py_tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)
- raw_input_str = "Let's see which is the better one It seems like this was important "
+ raw_input_str = (
+ "Let's see which is the better one It seems like this was important"
+ " "
+ )
rust_ids = rust_tokenizer([raw_input_str], return_tensors=None, add_special_tokens=False).input_ids[0]
py_ids = py_tokenizer([raw_input_str], return_tensors=None, add_special_tokens=False).input_ids[0]
self.assertListEqual(py_ids, rust_ids)
@@ -106,10 +109,9 @@ def test_large_seq2seq_truncation(self):
src_texts = ["This is going to be way too long." * 150, "short example"]
tgt_texts = ["not super long but more than 5 tokens", "tiny"]
batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
- with self._large_tokenizer.as_target_tokenizer():
- targets = self._large_tokenizer(
- tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
- )
+ targets = self._large_tokenizer(
+ text_target=tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
+ )
assert batch.input_ids.shape == (2, 1024)
assert batch.attention_mask.shape == (2, 1024)
@@ -158,7 +160,10 @@ def get_input_output_texts(self, tokenizer):
def test_mask_tokens_rust_pegasus(self):
rust_tokenizer = self.rust_tokenizer_class.from_pretrained(self.tmpdirname)
py_tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)
- raw_input_str = "Let's see which is the better one [MASK] It seems like this [MASK] was important "
+ raw_input_str = (
+ "Let's see which is the better one [MASK] It seems like this [MASK] was important "
+ " "
+ )
rust_ids = rust_tokenizer([raw_input_str], return_tensors=None, add_special_tokens=False).input_ids[0]
py_ids = py_tokenizer([raw_input_str], return_tensors=None, add_special_tokens=False).input_ids[0]
self.assertListEqual(py_ids, rust_ids)
@@ -168,10 +173,9 @@ def test_large_seq2seq_truncation(self):
src_texts = ["This is going to be way too long." * 1000, "short example"]
tgt_texts = ["not super long but more than 5 tokens", "tiny"]
batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
- with self._large_tokenizer.as_target_tokenizer():
- targets = self._large_tokenizer(
- tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
- )
+ targets = self._large_tokenizer(
+ text_target=tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
+ )
assert batch.input_ids.shape == (2, 4096)
assert batch.attention_mask.shape == (2, 4096)
@@ -198,7 +202,10 @@ def test_equivalence_to_orig_tokenizer(self):
tokenizer.tokenize(test_str)
"""
- test_str = "This is an example string that is used to test the original TF implementation against the HF implementation"
+ test_str = (
+ "This is an example string that is used to test the original TF implementation against the HF"
+ " implementation"
+ )
token_ids = self._large_tokenizer(test_str).input_ids
diff --git a/tests/models/perceiver/test_modeling_perceiver.py b/tests/models/perceiver/test_modeling_perceiver.py
index 1fc102bc40a8..5947a73a0e41 100644
--- a/tests/models/perceiver/test_modeling_perceiver.py
+++ b/tests/models/perceiver/test_modeling_perceiver.py
@@ -542,9 +542,12 @@ def recursive_check(tuple_object, dict_object):
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
- msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. "
- f"Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. "
- f"Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+ f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
+ f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
+ ),
)
recursive_check(tuple_output, dict_output)
@@ -767,7 +770,10 @@ def test_problem_types(self):
@require_torch_multi_gpu
@unittest.skip(
- reason="Perceiver does not work with data parallel (DP) because of a bug in PyTorch: https://github.com/pytorch/pytorch/issues/36035"
+ reason=(
+ "Perceiver does not work with data parallel (DP) because of a bug in PyTorch:"
+ " https://github.com/pytorch/pytorch/issues/36035"
+ )
)
def test_multi_gpu_data_parallel_forward(self):
pass
diff --git a/tests/models/perceiver/test_tokenization_perceiver.py b/tests/models/perceiver/test_tokenization_perceiver.py
index ca61e9c856f1..3c7a67bcd2b9 100644
--- a/tests/models/perceiver/test_tokenization_perceiver.py
+++ b/tests/models/perceiver/test_tokenization_perceiver.py
@@ -146,10 +146,9 @@ def test_max_length_integration(self):
"Summary of the text.",
"Another summary.",
]
- with tokenizer.as_target_tokenizer():
- targets = tokenizer(
- tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
- )
+ targets = tokenizer(
+ text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
+ )
self.assertEqual(32, targets["input_ids"].shape[1])
# cannot use default save_and_load_tokenzier test method because tokenzier has no vocab
diff --git a/tests/models/plbart/test_modeling_plbart.py b/tests/models/plbart/test_modeling_plbart.py
index 073db546bf1a..171531503d2d 100644
--- a/tests/models/plbart/test_modeling_plbart.py
+++ b/tests/models/plbart/test_modeling_plbart.py
@@ -219,6 +219,7 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
)
all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
diff --git a/tests/models/plbart/test_tokenization_plbart.py b/tests/models/plbart/test_tokenization_plbart.py
index 9aed6040f3fd..2ce7cafbda6e 100644
--- a/tests/models/plbart/test_tokenization_plbart.py
+++ b/tests/models/plbart/test_tokenization_plbart.py
@@ -299,33 +299,26 @@ def test_special_tokens_unaffacted_by_save_load(self):
@require_torch
def test_batch_fairseq_parity(self):
- batch = self.tokenizer(self.src_text, padding=True)
- with self.tokenizer.as_target_tokenizer():
- targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
- labels = targets["input_ids"]
- batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
+ batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
+ batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
- self.assertEqual(batch.input_ids[1][-2:], [2, PYTHON_CODE])
+ self.assertEqual(batch.input_ids[1][-2:].tolist(), [2, PYTHON_CODE])
self.assertEqual(batch.decoder_input_ids[1][0], EN_CODE)
self.assertEqual(batch.decoder_input_ids[1][-1], 2)
- self.assertEqual(labels[1][-2:].tolist(), [2, EN_CODE])
+ self.assertEqual(batch.labels[1][-2:].tolist(), [2, EN_CODE])
@require_torch
def test_python_en_tokenizer_prepare_batch(self):
batch = self.tokenizer(
- self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
+ self.src_text,
+ text_target=self.tgt_text,
+ padding=True,
+ truncation=True,
+ max_length=len(self.expected_src_tokens),
+ return_tensors="pt",
)
- with self.tokenizer.as_target_tokenizer():
- targets = self.tokenizer(
- self.tgt_text,
- padding=True,
- truncation=True,
- max_length=len(self.expected_src_tokens),
- return_tensors="pt",
- )
- labels = targets["input_ids"]
- batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
+ batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
self.assertIsInstance(batch, BatchEncoding)
@@ -340,8 +333,9 @@ def test_python_en_tokenizer_prepare_batch(self):
def test_seq2seq_max_length(self):
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
- with self.tokenizer.as_target_tokenizer():
- targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
+ targets = self.tokenizer(
+ text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
+ )
labels = targets["input_ids"]
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
diff --git a/tests/models/poolformer/test_modeling_poolformer.py b/tests/models/poolformer/test_modeling_poolformer.py
index 9bb8fa2e29cd..7dc47d2c77f9 100644
--- a/tests/models/poolformer/test_modeling_poolformer.py
+++ b/tests/models/poolformer/test_modeling_poolformer.py
@@ -142,6 +142,10 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
+ @unittest.skip(reason="PoolFormer does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
@unittest.skip("PoolFormer does not use inputs_embeds")
def test_inputs_embeds(self):
pass
diff --git a/tests/models/prophetnet/test_modeling_prophetnet.py b/tests/models/prophetnet/test_modeling_prophetnet.py
index e17e14072af5..9ac8ea81e20a 100644
--- a/tests/models/prophetnet/test_modeling_prophetnet.py
+++ b/tests/models/prophetnet/test_modeling_prophetnet.py
@@ -1226,7 +1226,15 @@ def test_cnndm_inference(self):
tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased-cnndm")
- ARTICLE_TO_SUMMARIZE = "USTC was founded in Beijing by the Chinese Academy of Sciences (CAS) in September 1958. The Director of CAS, Mr. Guo Moruo was appointed the first president of USTC. USTC's founding mission was to develop a high-level science and technology workforce, as deemed critical for development of China's economy, defense, and science and technology education. The establishment was hailed as \"A Major Event in the History of Chinese Education and Science.\" CAS has supported USTC by combining most of its institutes with the departments of the university. USTC is listed in the top 16 national key universities, becoming the youngest national key university.".lower()
+ ARTICLE_TO_SUMMARIZE = (
+ "USTC was founded in Beijing by the Chinese Academy of Sciences (CAS) in September 1958. The Director of"
+ " CAS, Mr. Guo Moruo was appointed the first president of USTC. USTC's founding mission was to develop a"
+ " high-level science and technology workforce, as deemed critical for development of China's economy,"
+ ' defense, and science and technology education. The establishment was hailed as "A Major Event in the'
+ ' History of Chinese Education and Science." CAS has supported USTC by combining most of its institutes'
+ " with the departments of the university. USTC is listed in the top 16 national key universities, becoming"
+ " the youngest national key university.".lower()
+ )
input_ids = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=511, return_tensors="pt").input_ids
input_ids = input_ids.to(torch_device)
@@ -1234,7 +1242,10 @@ def test_cnndm_inference(self):
summary_ids = model.generate(
input_ids, num_beams=4, length_penalty=1.0, no_repeat_ngram_size=3, early_stopping=True
)
- EXPECTED_SUMMARIZE_512 = "us ##tc was founded by the chinese academy of sciences ( cas ) in 1958 . [X_SEP] us ##tc is listed in the top 16 national key universities ."
+ EXPECTED_SUMMARIZE_512 = (
+ "us ##tc was founded by the chinese academy of sciences ( cas ) in 1958 . [X_SEP] us ##tc is listed in the"
+ " top 16 national key universities ."
+ )
generated_titles = [
" ".join(tokenizer.convert_ids_to_tokens(g, skip_special_tokens=True)) for g in summary_ids
]
@@ -1251,7 +1262,8 @@ def test_cnndm_inference(self):
EXPECTED_SUMMARIZE_100 = (
r"us ##tc was founded in beijing by the chinese academy of sciences ( cas ) in 1958 . [X_SEP] us ##tc "
"'"
- ' s founding mission was to develop a high - level science and technology workforce . [X_SEP] establishment hailed as " a major event in the history of chinese education and science "'
+ " s founding mission was to develop a high - level science and technology workforce . [X_SEP]"
+ ' establishment hailed as " a major event in the history of chinese education and science "'
)
generated_titles = [
" ".join(tokenizer.convert_ids_to_tokens(g, skip_special_tokens=True)) for g in summary_ids
diff --git a/tests/models/prophetnet/test_tokenization_prophetnet.py b/tests/models/prophetnet/test_tokenization_prophetnet.py
index 5b44879d04b5..8d95eb310025 100644
--- a/tests/models/prophetnet/test_tokenization_prophetnet.py
+++ b/tests/models/prophetnet/test_tokenization_prophetnet.py
@@ -141,7 +141,7 @@ def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
vocab = {}
- for (i, token) in enumerate(vocab_tokens):
+ for i, token in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
diff --git a/tests/models/rag/test_modeling_tf_rag.py b/tests/models/rag/test_modeling_tf_rag.py
index d9050acb6311..314ce099baf6 100644
--- a/tests/models/rag/test_modeling_tf_rag.py
+++ b/tests/models/rag/test_modeling_tf_rag.py
@@ -838,13 +838,6 @@ def test_rag_token_generate_batch(self):
input_ids = input_dict.input_ids
attention_mask = input_dict.attention_mask
- output_ids = rag_token.generate(
- input_ids,
- attention_mask=attention_mask,
- )
-
- outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-
EXPECTED_OUTPUTS = [
" albert einstein",
" september 22, 2017",
@@ -855,7 +848,21 @@ def test_rag_token_generate_batch(self):
" 7.1. 2",
" 13",
]
- self.assertListEqual(outputs, EXPECTED_OUTPUTS)
+
+ # Split into 2 batches of 4 examples to avoid GPU OOM.
+ output_ids = rag_token.generate(
+ input_ids[:4],
+ attention_mask=attention_mask[:4],
+ )
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+ self.assertListEqual(outputs, EXPECTED_OUTPUTS[:4])
+
+ output_ids = rag_token.generate(
+ input_ids[4:],
+ attention_mask=attention_mask[4:],
+ )
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+ self.assertListEqual(outputs, EXPECTED_OUTPUTS[4:])
@slow
def test_rag_sequence_generate_batch(self):
diff --git a/tests/models/realm/test_tokenization_realm.py b/tests/models/realm/test_tokenization_realm.py
index a54da2898032..2a065ceee66a 100644
--- a/tests/models/realm/test_tokenization_realm.py
+++ b/tests/models/realm/test_tokenization_realm.py
@@ -186,7 +186,7 @@ def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
vocab = {}
- for (i, token) in enumerate(vocab_tokens):
+ for i, token in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
diff --git a/tests/models/reformer/test_modeling_reformer.py b/tests/models/reformer/test_modeling_reformer.py
index 1929867521a3..0e5a801e7efb 100644
--- a/tests/models/reformer/test_modeling_reformer.py
+++ b/tests/models/reformer/test_modeling_reformer.py
@@ -574,7 +574,10 @@ def test_reformer_model_fp16_generate(self):
@require_torch_multi_gpu
@unittest.skip(
- reason="Reformer does not work with data parallel (DP) because of a bug in PyTorch: https://github.com/pytorch/pytorch/issues/36035"
+ reason=(
+ "Reformer does not work with data parallel (DP) because of a bug in PyTorch:"
+ " https://github.com/pytorch/pytorch/issues/36035"
+ )
)
def test_multi_gpu_data_parallel_forward(self):
pass
diff --git a/tests/models/reformer/test_tokenization_reformer.py b/tests/models/reformer/test_tokenization_reformer.py
index 32f946c49760..37ea66847f2d 100644
--- a/tests/models/reformer/test_tokenization_reformer.py
+++ b/tests/models/reformer/test_tokenization_reformer.py
@@ -214,7 +214,10 @@ def test_tokenization_base_easy_symbols(self):
@slow
def test_tokenization_base_hard_symbols(self):
- symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth'
+ symbols = (
+ 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
+ " add words that should not exsist and be tokenized to , such as saoneuhaoesuth"
+ )
original_tokenizer_encodings = [
108,
265,
diff --git a/tests/models/regnet/test_modeling_regnet.py b/tests/models/regnet/test_modeling_regnet.py
index 02695dbf6434..4879bf259efc 100644
--- a/tests/models/regnet/test_modeling_regnet.py
+++ b/tests/models/regnet/test_modeling_regnet.py
@@ -147,6 +147,10 @@ def test_config(self):
def create_and_test_config_common_properties(self):
return
+ @unittest.skip(reason="RegNet does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
@unittest.skip(reason="RegNet does not use inputs_embeds")
def test_inputs_embeds(self):
pass
diff --git a/tests/models/regnet/test_modeling_tf_regnet.py b/tests/models/regnet/test_modeling_tf_regnet.py
new file mode 100644
index 000000000000..c7504c92fa35
--- /dev/null
+++ b/tests/models/regnet/test_modeling_tf_regnet.py
@@ -0,0 +1,289 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the TensorFlow RegNet model. """
+
+import inspect
+import unittest
+from typing import List, Tuple
+
+from transformers import RegNetConfig
+from transformers.testing_utils import require_tf, require_vision, slow
+from transformers.utils import cached_property, is_tf_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_tf_available():
+ import tensorflow as tf
+
+ from transformers import TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST, TFRegNetForImageClassification, TFRegNetModel
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import AutoFeatureExtractor
+
+
+class TFRegNetModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=3,
+ image_size=32,
+ num_channels=3,
+ embeddings_size=10,
+ hidden_sizes=[10, 20, 30, 40],
+ depths=[1, 1, 2, 1],
+ is_training=True,
+ use_labels=True,
+ hidden_act="relu",
+ num_labels=3,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.num_channels = num_channels
+ self.embeddings_size = embeddings_size
+ self.hidden_sizes = hidden_sizes
+ self.depths = depths
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.hidden_act = hidden_act
+ self.num_labels = num_labels
+ self.scope = scope
+ self.num_stages = len(hidden_sizes)
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.num_labels)
+
+ config = self.get_config()
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return RegNetConfig(
+ num_channels=self.num_channels,
+ embeddings_size=self.embeddings_size,
+ hidden_sizes=self.hidden_sizes,
+ depths=self.depths,
+ hidden_act=self.hidden_act,
+ num_labels=self.num_labels,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = TFRegNetModel(config=config)
+ result = model(pixel_values, training=False)
+ # expected last hidden states: B, C, H // 32, W // 32
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
+ )
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.num_labels
+ model = TFRegNetForImageClassification(config)
+ result = model(pixel_values, labels=labels, training=False)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_tf
+class TFRegNetModelTest(TFModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_tf_common.py, as RegNet does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (TFRegNetModel, TFRegNetForImageClassification) if is_tf_available() else ()
+
+ test_pruning = False
+ test_onnx = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ has_attentions = False
+
+ def setUp(self):
+ self.model_tester = TFRegNetModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=RegNetConfig, has_text_modality=False)
+
+ def create_and_test_config_common_properties(self):
+ return
+
+ @unittest.skip(reason="RegNet does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skipIf(
+ not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
+ reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
+ )
+ def test_keras_fit(self):
+ pass
+
+ @unittest.skip(reason="RegNet does not support input and output embeddings")
+ def test_model_common_attributes(self):
+ pass
+
+ @unittest.skip(reason="Model doesn't have attention layers")
+ def test_attention_outputs(self):
+ pass
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.call)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
+
+ hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+
+ expected_num_stages = self.model_tester.num_stages
+ self.assertEqual(len(hidden_states), expected_num_stages + 1)
+
+ # RegNet's feature maps are of shape (batch_size, num_channels, height, width)
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [self.model_tester.image_size // 2, self.model_tester.image_size // 2],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ layers_type = ["basic", "bottleneck"]
+ for model_class in self.all_model_classes:
+ for layer_type in layers_type:
+ config.layer_type = layer_type
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # Since RegNet does not have any attention we need to rewrite this test.
+ def test_model_outputs_equivalence(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+ tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
+ dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
+
+ def recursive_check(tuple_object, dict_object):
+ if isinstance(tuple_object, (List, Tuple)):
+ for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
+ recursive_check(tuple_iterable_value, dict_iterable_value)
+ elif tuple_object is None:
+ return
+ else:
+ self.assertTrue(
+ all(tf.equal(tuple_object, dict_object)),
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}"
+ ),
+ )
+
+ recursive_check(tuple_output, dict_output)
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+ check_equivalence(model, tuple_inputs, dict_inputs)
+
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ check_equivalence(model, tuple_inputs, dict_inputs)
+
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+ check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = TFRegNetModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_tf
+@require_vision
+class RegNetModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return (
+ AutoFeatureExtractor.from_pretrained(TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[0])
+ if is_vision_available()
+ else None
+ )
+
+ @slow
+ def test_inference_image_classification_head(self):
+ model = TFRegNetForImageClassification.from_pretrained(TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[0])
+
+ feature_extractor = self.default_feature_extractor
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="tf")
+
+ # forward pass
+ outputs = model(**inputs, training=False)
+
+ # verify the logits
+ expected_shape = tf.TensorShape((1, 1000))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = tf.constant([-0.4180, -1.5051, -3.4836])
+
+ tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4)
diff --git a/tests/models/resnet/test_modeling_resnet.py b/tests/models/resnet/test_modeling_resnet.py
index f289c5c3df84..83f08b68afb8 100644
--- a/tests/models/resnet/test_modeling_resnet.py
+++ b/tests/models/resnet/test_modeling_resnet.py
@@ -147,6 +147,10 @@ def test_config(self):
def create_and_test_config_common_properties(self):
return
+ @unittest.skip(reason="ResNet does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
@unittest.skip(reason="ResNet does not use inputs_embeds")
def test_inputs_embeds(self):
pass
diff --git a/tests/models/resnet/test_modeling_tf_resnet.py b/tests/models/resnet/test_modeling_tf_resnet.py
new file mode 100644
index 000000000000..1056ebc8eeac
--- /dev/null
+++ b/tests/models/resnet/test_modeling_tf_resnet.py
@@ -0,0 +1,252 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the Tensorflow ResNet model. """
+
+
+import inspect
+import unittest
+
+import numpy as np
+
+from transformers import ResNetConfig
+from transformers.testing_utils import require_tf, require_vision, slow
+from transformers.utils import cached_property, is_tf_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_tf_available():
+ import tensorflow as tf
+
+ from transformers import TFResNetForImageClassification, TFResNetModel
+ from transformers.models.resnet.modeling_tf_resnet import TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import AutoFeatureExtractor
+
+
+class TFResNetModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=3,
+ image_size=32,
+ num_channels=3,
+ embeddings_size=10,
+ hidden_sizes=[10, 20, 30, 40],
+ depths=[1, 1, 2, 1],
+ is_training=True,
+ use_labels=True,
+ hidden_act="relu",
+ num_labels=3,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.num_channels = num_channels
+ self.embeddings_size = embeddings_size
+ self.hidden_sizes = hidden_sizes
+ self.depths = depths
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.hidden_act = hidden_act
+ self.num_labels = num_labels
+ self.scope = scope
+ self.num_stages = len(hidden_sizes)
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.num_labels)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return ResNetConfig(
+ num_channels=self.num_channels,
+ embeddings_size=self.embeddings_size,
+ hidden_sizes=self.hidden_sizes,
+ depths=self.depths,
+ hidden_act=self.hidden_act,
+ num_labels=self.num_labels,
+ image_size=self.image_size,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = TFResNetModel(config=config)
+ result = model(pixel_values)
+ # expected last hidden states: B, C, H // 32, W // 32
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
+ )
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.num_labels
+ model = TFResNetForImageClassification(config)
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_tf
+class TFResNetModelTest(TFModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_tf_common.py, as ResNet does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (TFResNetModel, TFResNetForImageClassification) if is_tf_available() else ()
+
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ test_onnx = False
+ has_attentions = False
+
+ def setUp(self):
+ self.model_tester = TFResNetModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=ResNetConfig, has_text_modality=False)
+
+ def test_config(self):
+ self.create_and_test_config_common_properties()
+ self.config_tester.create_and_test_config_to_json_string()
+ self.config_tester.create_and_test_config_to_json_file()
+ self.config_tester.create_and_test_config_from_and_save_pretrained()
+ self.config_tester.create_and_test_config_with_num_labels()
+ self.config_tester.check_config_can_be_init_without_params()
+ self.config_tester.check_config_arguments_init()
+
+ def create_and_test_config_common_properties(self):
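+ # Intentionally a no-op: ResNetConfig does not define the common attributes (e.g. hidden_size, num_attention_heads) this check expects.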
+ return
+
+ @unittest.skip(reason="ResNet does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="ResNet does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
+ @unittest.skip(reason="ResNet does not support input and output embeddings")
+ def test_model_common_attributes(self):
+ pass
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.call)
+ # signature.parameters is an OrderedDict, so the order of arg_names is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+
+ expected_num_stages = self.model_tester.num_stages
+ self.assertEqual(len(hidden_states), expected_num_stages + 1)
+
+ # ResNet's feature maps are of shape (batch_size, num_channels, height, width)
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [self.model_tester.image_size // 4, self.model_tester.image_size // 4],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ layers_type = ["basic", "bottleneck"]
+ for model_class in self.all_model_classes:
+ for layer_type in layers_type:
+ config.layer_type = layer_type
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = TFResNetModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_tf
+@require_vision
+class TFResNetModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return (
+ AutoFeatureExtractor.from_pretrained(TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[0])
+ if is_vision_available()
+ else None
+ )
+
+ @slow
+ def test_inference_image_classification_head(self):
+ model = TFResNetForImageClassification.from_pretrained(TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[0])
+
+ feature_extractor = self.default_feature_extractor
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="tf")
+
+ # forward pass
+ outputs = model(**inputs)
+
+ # verify the logits
+ expected_shape = tf.TensorShape((1, 1000))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = tf.constant([-11.1069, -9.7877, -8.3777])
+
+ self.assertTrue(np.allclose(outputs.logits[0, :3].numpy(), expected_slice, atol=1e-4))
diff --git a/tests/models/retribert/__init__.py b/tests/models/retribert/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/retribert/test_tokenization_retribert.py b/tests/models/retribert/test_tokenization_retribert.py
new file mode 100644
index 000000000000..e2bf4e61b1ac
--- /dev/null
+++ b/tests/models/retribert/test_tokenization_retribert.py
@@ -0,0 +1,384 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the RetriBERT tokenizer. """
+
+
+import os
+import unittest
+
+from transformers import RetriBertTokenizer, RetriBertTokenizerFast
+from transformers.models.bert.tokenization_bert import (
+ VOCAB_FILES_NAMES,
+ BasicTokenizer,
+ WordpieceTokenizer,
+ _is_control,
+ _is_punctuation,
+ _is_whitespace,
+)
+from transformers.testing_utils import require_tokenizers, require_torch, slow
+
+from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
+
+
+# Copied from transformers.tests.models.bert.test_tokenization_bert.py with Bert->RetriBert
+@require_tokenizers
+class RetriBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+
+ tokenizer_class = RetriBertTokenizer
+ test_slow_tokenizer = True
+ rust_tokenizer_class = RetriBertTokenizerFast
+ test_rust_tokenizer = True
+ space_between_special_tokens = True
+ from_pretrained_filter = filter_non_english
+
+ def setUp(self):
+ super().setUp()
+
+ vocab_tokens = [
+ "[UNK]",
+ "[CLS]",
+ "[SEP]",
+ "[PAD]",
+ "[MASK]",
+ "want",
+ "##want",
+ "##ed",
+ "wa",
+ "un",
+ "runn",
+ "##ing",
+ ",",
+ "low",
+ "lowest",
+ ]
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+ with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
+ vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+
+ def get_input_output_texts(self, tokenizer):
+ input_text = "UNwant\u00E9d,running"
+ output_text = "unwanted, running"
+ return input_text, output_text
+
+ def test_full_tokenizer(self):
+ tokenizer = self.tokenizer_class(self.vocab_file)
+
+ tokens = tokenizer.tokenize("UNwant\u00E9d,running")
+ self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+ self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11])
+
+ def test_rust_and_python_full_tokenizers(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ sequence = "UNwant\u00E9d,running"
+
+ tokens = tokenizer.tokenize(sequence)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ rust_tokenizer = self.get_rust_tokenizer()
+ ids = tokenizer.encode(sequence)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
+
+ # With lower casing
+ tokenizer = self.get_tokenizer(do_lower_case=True)
+ rust_tokenizer = self.get_rust_tokenizer(do_lower_case=True)
+
+ sequence = "UNwant\u00E9d,running"
+
+ tokens = tokenizer.tokenize(sequence)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ rust_tokenizer = self.get_rust_tokenizer()
+ ids = tokenizer.encode(sequence)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
+
+ def test_chinese(self):
+ tokenizer = BasicTokenizer()
+
+ self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"])
+
+ def test_basic_tokenizer_lower(self):
+ tokenizer = BasicTokenizer(do_lower_case=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["hello", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+ def test_basic_tokenizer_lower_strip_accents_false(self):
+ tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=False)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hällo", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"])
+
+ def test_basic_tokenizer_lower_strip_accents_true(self):
+ tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+ def test_basic_tokenizer_lower_strip_accents_default(self):
+ tokenizer = BasicTokenizer(do_lower_case=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+ def test_basic_tokenizer_no_lower(self):
+ tokenizer = BasicTokenizer(do_lower_case=False)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
+ )
+
+ def test_basic_tokenizer_no_lower_strip_accents_false(self):
+ tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=False)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HäLLo", "!", "how", "Are", "yoU", "?"]
+ )
+
+ def test_basic_tokenizer_no_lower_strip_accents_true(self):
+ tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"]
+ )
+
+ def test_basic_tokenizer_respects_never_split_tokens(self):
+ tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"])
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
+ )
+
+ def test_wordpiece_tokenizer(self):
+ vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
+
+ vocab = {}
+ for i, token in enumerate(vocab_tokens):
+ vocab[token] = i
+ tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
+
+ self.assertListEqual(tokenizer.tokenize(""), [])
+
+ self.assertListEqual(tokenizer.tokenize("unwanted running"), ["un", "##want", "##ed", "runn", "##ing"])
+
+ self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
+
+ def test_is_whitespace(self):
+ self.assertTrue(_is_whitespace(" "))
+ self.assertTrue(_is_whitespace("\t"))
+ self.assertTrue(_is_whitespace("\r"))
+ self.assertTrue(_is_whitespace("\n"))
+ self.assertTrue(_is_whitespace("\u00A0"))
+
+ self.assertFalse(_is_whitespace("A"))
+ self.assertFalse(_is_whitespace("-"))
+
+ def test_is_control(self):
+ self.assertTrue(_is_control("\u0005"))
+
+ self.assertFalse(_is_control("A"))
+ self.assertFalse(_is_control(" "))
+ self.assertFalse(_is_control("\t"))
+ self.assertFalse(_is_control("\r"))
+
+ def test_is_punctuation(self):
+ self.assertTrue(_is_punctuation("-"))
+ self.assertTrue(_is_punctuation("$"))
+ self.assertTrue(_is_punctuation("`"))
+ self.assertTrue(_is_punctuation("."))
+
+ self.assertFalse(_is_punctuation("A"))
+ self.assertFalse(_is_punctuation(" "))
+
+ def test_clean_text(self):
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ # Example taken from the issue https://github.com/huggingface/tokenizers/issues/340
+ self.assertListEqual([tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]])
+
+ self.assertListEqual(
+ [rust_tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]]
+ )
+
+ @slow
+ def test_sequence_builders(self):
+ tokenizer = self.tokenizer_class.from_pretrained("yjernite/retribert-base-uncased")
+
+ text = tokenizer.encode("sequence builders", add_special_tokens=False)
+ text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
+
+ encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
+ encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
+
+ assert encoded_sentence == [101] + text + [102]
+ assert encoded_pair == [101] + text + [102] + text_2 + [102]
+
+ def test_offsets_with_special_characters(self):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ sentence = f"A, naĆÆve {tokenizer_r.mask_token} AllenNLP sentence."
+ tokens = tokenizer_r.encode_plus(
+ sentence,
+ return_attention_mask=False,
+ return_token_type_ids=False,
+ return_offsets_mapping=True,
+ add_special_tokens=True,
+ )
+
+ do_lower_case = tokenizer_r.do_lower_case if hasattr(tokenizer_r, "do_lower_case") else False
+ expected_results = (
+ [
+ ((0, 0), tokenizer_r.cls_token),
+ ((0, 1), "A"),
+ ((1, 2), ","),
+ ((3, 5), "na"),
+ ((5, 6), "##ĆÆ"),
+ ((6, 8), "##ve"),
+ ((9, 15), tokenizer_r.mask_token),
+ ((16, 21), "Allen"),
+ ((21, 23), "##NL"),
+ ((23, 24), "##P"),
+ ((25, 33), "sentence"),
+ ((33, 34), "."),
+ ((0, 0), tokenizer_r.sep_token),
+ ]
+ if not do_lower_case
+ else [
+ ((0, 0), tokenizer_r.cls_token),
+ ((0, 1), "a"),
+ ((1, 2), ","),
+ ((3, 8), "naive"),
+ ((9, 15), tokenizer_r.mask_token),
+ ((16, 21), "allen"),
+ ((21, 23), "##nl"),
+ ((23, 24), "##p"),
+ ((25, 33), "sentence"),
+ ((33, 34), "."),
+ ((0, 0), tokenizer_r.sep_token),
+ ]
+ )
+
+ self.assertEqual(
+ [e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"])
+ )
+ self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
+
+ def test_change_tokenize_chinese_chars(self):
+ list_of_commun_chinese_char = ["的", "人", "有"]
+ text_with_chinese_char = "".join(list_of_commun_chinese_char)
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+
+ kwargs["tokenize_chinese_chars"] = True
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
+ ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
+
+ tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
+ tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)
+
+ # it is expected that each Chinese character is not preceded by "##"
+ self.assertListEqual(tokens_without_spe_char_p, list_of_commun_chinese_char)
+ self.assertListEqual(tokens_without_spe_char_r, list_of_commun_chinese_char)
+
+ kwargs["tokenize_chinese_chars"] = False
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
+ ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
+
+ tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
+ tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)
+
+ # it is expected that only the first Chinese character is not preceded by "##".
+ expected_tokens = [
+ f"##{token}" if idx != 0 else token for idx, token in enumerate(list_of_commun_chinese_char)
+ ]
+ self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
+ self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
+
+ # RetriBertModel doesn't define `get_input_embeddings`, and its forward method doesn't accept just the tokenizer output as input
+ @require_torch
+ @slow
+ def test_torch_encode_plus_sent_to_model(self):
+ import torch
+
+ from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
+
+ MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
+
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+
+ if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
+ return
+
+ config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
+ config = config_class()
+
+ if config.is_encoder_decoder or config.pad_token_id is None:
+ return
+
+ model = model_class(config)
+
+ # The following test is different from the common's one
+ self.assertGreaterEqual(model.bert_query.get_input_embeddings().weight.shape[0], len(tokenizer))
+
+ # Build sequence
+ first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
+ sequence = " ".join(first_ten_tokens)
+ encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
+
+ # Ensure that the BatchEncoding.to() method works.
+ encoded_sequence.to(model.device)
+
+ batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
+ # This should not fail
+
+ with torch.no_grad(): # saves some time
+ # The following lines are different from the common's ones
+ model.embed_questions(**encoded_sequence)
+ model.embed_questions(**batch_encoded_sequence)
diff --git a/tests/models/roberta/test_modeling_roberta.py b/tests/models/roberta/test_modeling_roberta.py
index e0b8b78b3b6c..7163a357021e 100644
--- a/tests/models/roberta/test_modeling_roberta.py
+++ b/tests/models/roberta/test_modeling_roberta.py
@@ -112,6 +112,11 @@ def get_config(self):
initializer_range=self.initializer_range,
)
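+ # Use a larger vocab size for pipeline tests so that token ids produced by the tokenizers used there stay in range.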
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def prepare_config_and_inputs_for_decoder(self):
(
config,
diff --git a/tests/models/segformer/test_modeling_segformer.py b/tests/models/segformer/test_modeling_segformer.py
index 9af59299f8ec..6a1d273f6642 100644
--- a/tests/models/segformer/test_modeling_segformer.py
+++ b/tests/models/segformer/test_modeling_segformer.py
@@ -18,7 +18,7 @@
import inspect
import unittest
-from transformers import is_torch_available, is_vision_available
+from transformers import SegformerConfig, is_torch_available, is_vision_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
@@ -31,7 +31,6 @@
from transformers import (
MODEL_MAPPING,
- SegformerConfig,
SegformerForImageClassification,
SegformerForSemanticSegmentation,
SegformerModel,
diff --git a/tests/models/segformer/test_modeling_tf_segformer.py b/tests/models/segformer/test_modeling_tf_segformer.py
new file mode 100644
index 000000000000..d6a73e22192c
--- /dev/null
+++ b/tests/models/segformer/test_modeling_tf_segformer.py
@@ -0,0 +1,504 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the TensorFlow SegFormer model. """
+
+import inspect
+import unittest
+from typing import List, Tuple
+
+from transformers import SegformerConfig
+from transformers.file_utils import is_tf_available, is_vision_available
+from transformers.testing_utils import require_tf, slow
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_tf_available():
+ import numpy as np
+ import tensorflow as tf
+
+ from transformers import TFSegformerForImageClassification, TFSegformerForSemanticSegmentation, TFSegformerModel
+ from transformers.models.segformer.modeling_tf_segformer import TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import SegformerFeatureExtractor
+
+
+class TFSegformerConfigTester(ConfigTester):
+ def create_and_test_config_common_properties(self):
+ config = self.config_class(**self.inputs_dict)
+ self.parent.assertTrue(hasattr(config, "hidden_sizes"))
+ self.parent.assertTrue(hasattr(config, "num_attention_heads"))
+ self.parent.assertTrue(hasattr(config, "num_encoder_blocks"))
+
+
+class TFSegformerModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=64,
+ num_channels=3,
+ num_encoder_blocks=4,
+ depths=[2, 2, 2, 2],
+ sr_ratios=[8, 4, 2, 1],
+ hidden_sizes=[16, 32, 64, 128],
+ downsampling_rates=[1, 4, 8, 16],
+ num_attention_heads=[1, 2, 4, 8],
+ is_training=True,
+ use_labels=True,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ initializer_range=0.02,
+ num_labels=3,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.num_channels = num_channels
+ self.num_encoder_blocks = num_encoder_blocks
+ self.sr_ratios = sr_ratios
+ self.depths = depths
+ self.hidden_sizes = hidden_sizes
+ self.downsampling_rates = downsampling_rates
+ self.num_attention_heads = num_attention_heads
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.scope = scope
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
+
+ config = self.get_config()
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return SegformerConfig(
+ image_size=self.image_size,
+ num_channels=self.num_channels,
+ num_encoder_blocks=self.num_encoder_blocks,
+ depths=self.depths,
+ hidden_sizes=self.hidden_sizes,
+ num_attention_heads=self.num_attention_heads,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ initializer_range=self.initializer_range,
+ num_labels=self.num_labels,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = TFSegformerModel(config=config)
+ result = model(pixel_values, training=False)
+ expected_height = expected_width = self.image_size // (self.downsampling_rates[-1] * 2)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.hidden_sizes[-1], expected_height, expected_width)
+ )
+
+ def create_and_check_for_image_segmentation(self, config, pixel_values, labels):
+ config.num_labels = self.num_labels
+ model = TFSegformerForSemanticSegmentation(config)
+ result = model(pixel_values, training=False)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
+ )
+ result = model(pixel_values, labels=labels, training=False)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+ def prepare_config_and_inputs_for_keras_fit(self, for_segmentation: bool = False):
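+ # `keras.fit()` needs explicit labels: per-pixel segmentation maps for the segmentation head, dummy class ids otherwise.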
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, seg_labels = config_and_inputs
+ if for_segmentation:
+ inputs_dict = {"pixel_values": pixel_values, "labels": seg_labels}
+ else:
+ inputs_dict = {"pixel_values": pixel_values, "labels": tf.zeros((self.batch_size))}
+ return config, inputs_dict
+
+
+@require_tf
+class TFSegformerModelTest(TFModelTesterMixin, unittest.TestCase):
+ all_model_classes = (
+ (TFSegformerModel, TFSegformerForImageClassification, TFSegformerForSemanticSegmentation)
+ if is_tf_available()
+ else ()
+ )
+
+ test_head_masking = False
+ test_onnx = False
+ test_pruning = False
+ test_resize_embeddings = False
+
+ def setUp(self):
+ self.model_tester = TFSegformerModelTester(self)
+ self.config_tester = TFSegformerConfigTester(self, config_class=SegformerConfig, has_text_modality=False)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ @unittest.skip("SegFormer does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip("SegFormer does not have get_input_embeddings method and get_output_embeddings methods")
+ def test_model_common_attributes(self):
+ pass
+
+ @unittest.skip("Test was written for TF 1.x and isn't really relevant here")
+ def test_compile_tf_model(self):
+ pass
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.call)
+ # signature.parameters is an OrderedDict, so the order of arg_names is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+
+ expected_num_attentions = sum(self.model_tester.depths)
+ self.assertEqual(len(attentions), expected_num_attentions)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+
+ self.assertEqual(len(attentions), expected_num_attentions)
+
+ # verify the first attentions (first block, first layer)
+ expected_seq_len = (self.model_tester.image_size // 4) ** 2
+ expected_reduced_seq_len = (self.model_tester.image_size // (4 * self.model_tester.sr_ratios[0])) ** 2
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads[0], expected_seq_len, expected_reduced_seq_len],
+ )
+
+ # verify the last attentions (last block, last layer)
+ expected_seq_len = (self.model_tester.image_size // 32) ** 2
+ expected_reduced_seq_len = (self.model_tester.image_size // (32 * self.model_tester.sr_ratios[-1])) ** 2
+ self.assertListEqual(
+ list(attentions[-1].shape[-3:]),
+ [self.model_tester.num_attention_heads[-1], expected_seq_len, expected_reduced_seq_len],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ self.assertEqual(out_len + 1, len(outputs))
+
+ self_attentions = outputs.attentions
+
+ self.assertEqual(len(self_attentions), expected_num_attentions)
+ # verify the first attentions (first block, first layer)
+ expected_seq_len = (self.model_tester.image_size // 4) ** 2
+ expected_reduced_seq_len = (self.model_tester.image_size // (4 * self.model_tester.sr_ratios[0])) ** 2
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads[0], expected_seq_len, expected_reduced_seq_len],
+ )
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.hidden_states
+
+ expected_num_layers = self.model_tester.num_encoder_blocks
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ # verify the first hidden states (first block)
+ self.assertListEqual(
+ list(hidden_states[0].shape[-3:]),
+ [
+ self.model_tester.hidden_sizes[0],
+ self.model_tester.image_size // 4,
+ self.model_tester.image_size // 4,
+ ],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ def test_model_outputs_equivalence(self):
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
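+ # Run the model once with return_dict=False and once with return_dict=True, then compare the outputs element-wise.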
+ tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
+ dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
+
+ def recursive_check(tuple_object, dict_object):
+ if isinstance(tuple_object, (List, Tuple)):
+ for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
+ recursive_check(tuple_iterable_value, dict_iterable_value)
+ elif tuple_object is None:
+ return
+ else:
+ self.assertTrue(
+ all(tf.equal(tuple_object, dict_object)),
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}"
+ ),
+ )
+
+ recursive_check(tuple_output, dict_output)
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+ check_equivalence(model, tuple_inputs, dict_inputs)
+
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+ check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+ if self.has_attentions:
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+ check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
+
+ # TODO: incorporate label support for semantic segmentation in `test_modeling_tf_common.py`.
+
+ @unittest.skipIf(
+ not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
+ reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
+ )
+ def test_dataset_conversion(self):
+ super().test_dataset_conversion()
+
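+ # Overridden from the common tests to compare the two `fit()` validation losses with looser tolerances.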
+ def check_keras_fit_results(self, val_loss1, val_loss2, atol=2e-1, rtol=2e-1):
+ self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol))
+
+ @unittest.skipIf(
+ not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
+ reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
+ )
+ def test_keras_fit(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ # `TFSegformerModel` cannot operate with the default `fit()` method, so it is skipped here.
+ if model_class.__name__ != "TFSegformerModel":
+ model = model_class(config)
+ if getattr(model, "hf_compute_loss", None):
+ super().test_keras_fit()
+
+ def test_loss_computation(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ def apply(model):
+ for_segmentation = model_class.__name__ == "TFSegformerForSemanticSegmentation"
+ # The number of elements in the loss should be the same as the number of elements in the label
+ _, prepared_for_class = self.model_tester.prepare_config_and_inputs_for_keras_fit(
+ for_segmentation=for_segmentation
+ )
+ added_label = prepared_for_class[
+ sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
+ ]
+ loss_size = tf.size(added_label)
+
+ # Test that the model correctly computes the loss with kwargs
+ possible_input_names = {"input_ids", "pixel_values", "input_features"}
+ input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
+ model_input = prepared_for_class.pop(input_name)
+
+ loss = model(model_input, **prepared_for_class)[0]
+
+ if model_class.__name__ == "TFSegformerForSemanticSegmentation":
+ # Semantic segmentation loss is computed similarly to
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210.
+ self.assertEqual(loss.shape, (1,))
+ else:
+ self.assertEqual(loss.shape, [loss_size])
+
+ # Test that the model correctly computes the loss with a dict
+ _, prepared_for_class = self.model_tester.prepare_config_and_inputs_for_keras_fit(
+ for_segmentation=for_segmentation
+ )
+ loss = model(**prepared_for_class)[0]
+
+ if model_class.__name__ == "TFSegformerForSemanticSegmentation":
+ self.assertEqual(loss.shape, (1,))
+ else:
+ self.assertEqual(loss.shape, [loss_size])
+
+ # Test that the model correctly computes the loss with a tuple
+ label_keys = prepared_for_class.keys() - inputs_dict.keys()
+ signature = inspect.signature(model.call).parameters
+ signature_names = list(signature.keys())
+
+ # Create a dictionary holding the location of the tensors in the tuple
+ tuple_index_mapping = {0: input_name}
+ for label_key in label_keys:
+ label_key_index = signature_names.index(label_key)
+ tuple_index_mapping[label_key_index] = label_key
+ sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
+ # Initialize a list with their default values, update the values and convert to a tuple
+ list_input = []
+
+ for name in signature_names:
+ if name != "kwargs":
+ list_input.append(signature[name].default)
+
+ for index, value in sorted_tuple_index_mapping:
+ list_input[index] = prepared_for_class[value]
+
+ tuple_input = tuple(list_input)
+
+ # Send to model
+ loss = model(tuple_input[:-1])[0]
+ if model_class.__name__ == "TFSegformerForSemanticSegmentation":
+ self.assertEqual(loss.shape, (1,))
+ else:
+ self.assertEqual(loss.shape, [loss_size])
+
+ for model_class in self.all_model_classes:
+ # `TFSegformerModel` has no labels against which a loss could be computed, so it is skipped here.
+ if model_class.__name__ != "TFSegformerModel":
+ model = model_class(config)
+ apply(model)
+
+ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
+ # We override with a slightly higher tol value, as semseg models tend to diverge a bit more
+ super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = TFSegformerModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_tf
+class TFSegformerModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference_image_segmentation_ade(self):
+ # only resize + normalize
+ feature_extractor = SegformerFeatureExtractor(
+ image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
+ )
+ model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
+
+ image = prepare_img()
+ encoded_inputs = feature_extractor(images=image, return_tensors="tf")
+ pixel_values = encoded_inputs.pixel_values
+
+ outputs = model(pixel_values, training=False)
+
+ expected_shape = tf.TensorShape((1, model.config.num_labels, 128, 128))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = tf.constant(
+ [
+ [[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
+ [[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
+ [[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
+ ]
+ )
+ tf.debugging.assert_near(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-4)
+
+ @slow
+ def test_inference_image_segmentation_city(self):
+ # only resize + normalize
+ feature_extractor = SegformerFeatureExtractor(
+ image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
+ )
+ model = TFSegformerForSemanticSegmentation.from_pretrained(
+ "nvidia/segformer-b1-finetuned-cityscapes-1024-1024"
+ )
+
+ image = prepare_img()
+ encoded_inputs = feature_extractor(images=image, return_tensors="tf")
+ pixel_values = encoded_inputs.pixel_values
+
+ outputs = model(pixel_values, training=False)
+
+ expected_shape = tf.TensorShape((1, model.config.num_labels, 128, 128))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = tf.constant(
+ [
+ [[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
+ [[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
+ [[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
+ ]
+ )
+ tf.debugging.assert_near(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1)
diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py
index 08b94b6465d5..a1a625a9b403 100644
--- a/tests/models/speech_to_text/test_modeling_speech_to_text.py
+++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py
@@ -17,6 +17,7 @@
import copy
import inspect
import os
+import pickle
import tempfile
import unittest
@@ -30,7 +31,7 @@
slow,
torch_device,
)
-from transformers.utils import cached_property
+from transformers.utils import cached_property, is_torch_fx_available
from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -43,6 +44,9 @@
from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor
from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder
+if is_torch_fx_available():
+ from transformers.utils.fx import symbolic_trace
+
def prepare_speech_to_text_inputs_dict(
config,
@@ -271,6 +275,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
@@ -715,6 +720,105 @@ def _create_and_check_torchscript(self, config, inputs_dict):
self.assertTrue(models_equal)
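+ # Traces the model with torch.fx, checks the traced outputs against the eager outputs, and verifies that the traced module can be pickled and reloaded.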
+ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+ if not is_torch_fx_available() or not self.fx_compatible:
+ return
+
+ configs_no_init = _config_zero_init(config) # To be sure we have no NaNs
+ configs_no_init.return_dict = False
+
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
+
+ try:
+ if model.config.is_encoder_decoder:
+ model.config.use_cache = False # FSMT still requires this hack -> FSMT should probably be refactored similarly to BART afterward
+ labels = inputs.get("labels", None)
+ input_names = [
+ "input_ids",
+ "attention_mask",
+ "decoder_input_ids",
+ "decoder_attention_mask",
+ "input_features",
+ ]
+ if labels is not None:
+ input_names.append("labels")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+ else:
+ input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values", "input_features"]
+
+ labels = inputs.get("labels", None)
+ start_positions = inputs.get("start_positions", None)
+ end_positions = inputs.get("end_positions", None)
+ if labels is not None:
+ input_names.append("labels")
+ if start_positions is not None:
+ input_names.append("start_positions")
+ if end_positions is not None:
+ input_names.append("end_positions")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+
+ except RuntimeError as e:
+ self.fail(f"Couldn't trace module: {e}")
+
+ def flatten_output(output):
+ flatten = []
+ for x in output:
+ if isinstance(x, (tuple, list)):
+ flatten += flatten_output(x)
+ elif not isinstance(x, torch.Tensor):
+ continue
+ else:
+ flatten.append(x)
+ return flatten
+
+ model_output = flatten_output(model_output)
+ traced_output = flatten_output(traced_output)
+ num_outputs = len(model_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], traced_output[i]),
+ f"traced {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
+ # Test that the model can be serialized and restored properly
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+ try:
+ with open(pkl_file_name, "wb") as f:
+ pickle.dump(traced_model, f)
+ with open(pkl_file_name, "rb") as f:
+ loaded = pickle.load(f)
+ except Exception as e:
+ self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+ loaded_output = loaded(**filtered_inputs)
+ loaded_output = flatten_output(loaded_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], loaded_output[i]),
+ f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
@require_torch
@require_torchaudio
@@ -770,8 +874,10 @@ def test_generation_librispeech_batched(self):
EXPECTED_TRANSCRIPTIONS = [
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
"nor is mister cultar's manner less interesting than his matter",
- "he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind",
- "he has grave doubts whether sir frederick leyton's work is really greek after all and can discover in it but little of rocky ithaca",
+ "he tells us that at this festive season of the year with christmas and roast beef looming before us"
+ " similes drawn from eating and its results occur most readily to the mind",
+ "he has grave doubts whether sir frederick leyton's work is really greek after all and can discover in it"
+ " but little of rocky ithaca",
]
self.assertListEqual(generated_transcripts, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/models/speech_to_text/test_modeling_tf_speech_to_text.py b/tests/models/speech_to_text/test_modeling_tf_speech_to_text.py
index 6485690645a9..613af6be0cd0 100644
--- a/tests/models/speech_to_text/test_modeling_tf_speech_to_text.py
+++ b/tests/models/speech_to_text/test_modeling_tf_speech_to_text.py
@@ -602,7 +602,9 @@ def test_generation_librispeech_batched(self):
EXPECTED_TRANSCRIPTIONS = [
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
"nor is mister cultar's manner less interesting than his matter",
- "he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind",
- "he has grave doubts whether sir frederick leyton's work is really greek after all and can discover in it but little of rocky ithaca",
+ "he tells us that at this festive season of the year with christmas and roast beef looming before us"
+ " similes drawn from eating and its results occur most readily to the mind",
+ "he has grave doubts whether sir frederick leyton's work is really greek after all and can discover in it"
+ " but little of rocky ithaca",
]
self.assertListEqual(generated_transcripts, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/models/speech_to_text/test_processor_speech_to_text.py b/tests/models/speech_to_text/test_processor_speech_to_text.py
index e6e43f1bb8d7..d519f005d3eb 100644
--- a/tests/models/speech_to_text/test_processor_speech_to_text.py
+++ b/tests/models/speech_to_text/test_processor_speech_to_text.py
@@ -125,8 +125,7 @@ def test_tokenizer(self):
input_str = "This is a test string"
- with processor.as_target_processor():
- encoded_processor = processor(input_str)
+ encoded_processor = processor(text=input_str)
encoded_tok = tokenizer(input_str)
diff --git a/tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py b/tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py
index 88d055067575..d9717b406049 100644
--- a/tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py
+++ b/tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py
@@ -179,6 +179,7 @@ def prepare_config_and_inputs_for_common(self):
class Speech2Text2StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Speech2Text2Decoder, Speech2Text2ForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (Speech2Text2ForCausalLM,) if is_torch_available() else ()
+ fx_compatible = True
test_pruning = False
def setUp(
diff --git a/tests/models/splinter/test_modeling_splinter.py b/tests/models/splinter/test_modeling_splinter.py
index 9b62b822c098..f064611b6a9e 100644
--- a/tests/models/splinter/test_modeling_splinter.py
+++ b/tests/models/splinter/test_modeling_splinter.py
@@ -14,11 +14,11 @@
# limitations under the License.
""" Testing suite for the PyTorch Splinter model. """
-
+import copy
import unittest
from transformers import is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -27,7 +27,7 @@
if is_torch_available():
import torch
- from transformers import SplinterConfig, SplinterForQuestionAnswering, SplinterModel
+ from transformers import SplinterConfig, SplinterForPreTraining, SplinterForQuestionAnswering, SplinterModel
from transformers.models.splinter.modeling_splinter import SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -36,6 +36,7 @@ def __init__(
self,
parent,
batch_size=13,
+ num_questions=3,
seq_length=7,
is_training=True,
use_input_mask=True,
@@ -43,6 +44,7 @@ def __init__(
use_labels=True,
vocab_size=99,
hidden_size=32,
+ question_token_id=1,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
@@ -59,6 +61,7 @@ def __init__(
):
self.parent = parent
self.batch_size = batch_size
+ self.num_questions = num_questions
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
@@ -66,6 +69,7 @@ def __init__(
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
+ self.question_token_id = question_token_id
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
@@ -82,6 +86,7 @@ def __init__(
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ input_ids[:, 1] = self.question_token_id
input_mask = None
if self.use_input_mask:
@@ -91,13 +96,13 @@ def prepare_config_and_inputs(self):
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
- sequence_labels = None
- token_labels = None
- choice_labels = None
+ start_positions = None
+ end_positions = None
+ question_positions = None
if self.use_labels:
- sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
- token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
- choice_labels = ids_tensor([self.batch_size], self.num_choices)
+ start_positions = ids_tensor([self.batch_size, self.num_questions], self.type_sequence_label_size)
+ end_positions = ids_tensor([self.batch_size, self.num_questions], self.type_sequence_label_size)
+ question_positions = ids_tensor([self.batch_size, self.num_questions], self.num_labels)
config = SplinterConfig(
vocab_size=self.vocab_size,
@@ -112,12 +117,20 @@ def prepare_config_and_inputs(self):
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
+ question_token_id=self.question_token_id,
)
- return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ return (config, input_ids, token_type_ids, input_mask, start_positions, end_positions, question_positions)
def create_and_check_model(
- self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ start_positions,
+ end_positions,
+ question_positions,
):
model = SplinterModel(config=config)
model.to(torch_device)
@@ -128,7 +141,14 @@ def create_and_check_model(
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_question_answering(
- self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ start_positions,
+ end_positions,
+ question_positions,
):
model = SplinterForQuestionAnswering(config=config)
model.to(torch_device)
@@ -137,12 +157,36 @@ def create_and_check_for_question_answering(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
- start_positions=sequence_labels,
- end_positions=sequence_labels,
+ start_positions=start_positions[:, 0],
+ end_positions=end_positions[:, 0],
)
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
+ def create_and_check_for_pretraining(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ start_positions,
+ end_positions,
+ question_positions,
+ ):
+ model = SplinterForPreTraining(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ start_positions=start_positions,
+ end_positions=end_positions,
+ question_positions=question_positions,
+ )
+ self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.num_questions, self.seq_length))
+ self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.num_questions, self.seq_length))
+
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -150,11 +194,15 @@ def prepare_config_and_inputs_for_common(self):
input_ids,
token_type_ids,
input_mask,
- sequence_labels,
- token_labels,
- choice_labels,
+ start_positions,
+ end_positions,
+ question_positions,
) = config_and_inputs
- inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
+ inputs_dict = {
+ "input_ids": input_ids,
+ "token_type_ids": token_type_ids,
+ "attention_mask": input_mask,
+ }
return config, inputs_dict
@@ -165,11 +213,44 @@ class SplinterModelTest(ModelTesterMixin, unittest.TestCase):
(
SplinterModel,
SplinterForQuestionAnswering,
+ SplinterForPreTraining,
)
if is_torch_available()
else ()
)
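+ # Overridden from the common tests: `SplinterForPreTraining` additionally needs `question_positions` (and per-question start/end positions) when labels are requested.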
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ inputs_dict = copy.deepcopy(inputs_dict)
+ if return_labels:
+ if issubclass(model_class, SplinterForPreTraining):
+ inputs_dict["start_positions"] = torch.zeros(
+ self.model_tester.batch_size,
+ self.model_tester.num_questions,
+ dtype=torch.long,
+ device=torch_device,
+ )
+ inputs_dict["end_positions"] = torch.zeros(
+ self.model_tester.batch_size,
+ self.model_tester.num_questions,
+ dtype=torch.long,
+ device=torch_device,
+ )
+ inputs_dict["question_positions"] = torch.zeros(
+ self.model_tester.batch_size,
+ self.model_tester.num_questions,
+ dtype=torch.long,
+ device=torch_device,
+ )
+ elif issubclass(model_class, SplinterForQuestionAnswering):
+ inputs_dict["start_positions"] = torch.zeros(
+ self.model_tester.batch_size, dtype=torch.long, device=torch_device
+ )
+ inputs_dict["end_positions"] = torch.zeros(
+ self.model_tester.batch_size, dtype=torch.long, device=torch_device
+ )
+
+ return inputs_dict
+
def setUp(self):
self.model_tester = SplinterModelTester(self)
self.config_tester = ConfigTester(self, config_class=SplinterConfig, hidden_size=37)
@@ -191,12 +272,86 @@ def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
+ def test_for_pretraining(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
+
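+ # Overridden from the common tests: calling `SplinterForPreTraining` without `question_positions` is expected to raise a TypeError.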
+ def test_inputs_embeds(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+ if not self.is_encoder_decoder:
+ input_ids = inputs["input_ids"]
+ del inputs["input_ids"]
+ else:
+ encoder_input_ids = inputs["input_ids"]
+ decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+ del inputs["input_ids"]
+ inputs.pop("decoder_input_ids", None)
+
+ wte = model.get_input_embeddings()
+ if not self.is_encoder_decoder:
+ inputs["inputs_embeds"] = wte(input_ids)
+ else:
+ inputs["inputs_embeds"] = wte(encoder_input_ids)
+ inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
+
+ with torch.no_grad():
+ if isinstance(model, SplinterForPreTraining):
+ with self.assertRaises(TypeError):
+ # question_positions must not be None.
+ model(**inputs)[0]
+ else:
+ model(**inputs)[0]
+
@slow
def test_model_from_pretrained(self):
for model_name in SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = SplinterModel.from_pretrained(model_name)
self.assertIsNotNone(model)
+ # Overwritten from the common tests since `SplinterForPreTraining` inputs can contain a different number of question
+ # tokens. When the batch is distributed to multiple devices, each replica can receive a different value for the
+ # maximal number of question tokens (see `SplinterForPreTraining._prepare_question_positions()`), so the replicas
+ # return outputs whose shape along dimension 1 (i.e. `num_questions`) differs and cannot be combined into one tensor.
+ @require_torch_multi_gpu
+ def test_multi_gpu_data_parallel_forward(self):
+ from torch import nn
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ # some params shouldn't be scattered by nn.DataParallel
+ # so just remove them if they are present.
+ blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
+ for k in blacklist_non_batched_params:
+ inputs_dict.pop(k, None)
+
+ # move input tensors to cuda:0
+ for k, v in inputs_dict.items():
+ if torch.is_tensor(v):
+ inputs_dict[k] = v.to(0)
+
+ for model_class in self.all_model_classes:
+
+ # Skip this case since it will fail sometimes, as described above.
+ if model_class == SplinterForPreTraining:
+ continue
+
+ model = model_class(config=config)
+ model.to(0)
+ model.eval()
+
+ # Wrap model in nn.DataParallel
+ model = nn.DataParallel(model)
+ with torch.no_grad():
+ _ = model(**self._prepare_for_class(inputs_dict, model_class))
+
@require_torch
class SplinterModelIntegrationTest(unittest.TestCase):
@@ -217,3 +372,122 @@ def test_splinter_question_answering(self):
self.assertEqual(torch.argmax(output.start_logits), 10)
self.assertEqual(torch.argmax(output.end_logits), 12)
+
+ @slow
+ def test_splinter_pretraining(self):
+ model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")
+
+ # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]"
+ # Output should be the spans "Brad" and "the United Kingdom"
+ input_ids = torch.tensor(
+ [[101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102]]
+ )
+ question_positions = torch.tensor([[1, 5]], dtype=torch.long)
+ output = model(input_ids, question_positions=question_positions)
+
+ expected_shape = torch.Size((1, 2, 16))
+ self.assertEqual(output.start_logits.shape, expected_shape)
+ self.assertEqual(output.end_logits.shape, expected_shape)
+
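+ # in the input above, "Brad" sits at position 7 and "the United Kingdom" spans positions 10-12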
+ self.assertEqual(torch.argmax(output.start_logits[0, 0]), 7)
+ self.assertEqual(torch.argmax(output.end_logits[0, 0]), 7)
+ self.assertEqual(torch.argmax(output.start_logits[0, 1]), 10)
+ self.assertEqual(torch.argmax(output.end_logits[0, 1]), 12)
+
+ @slow
+ def test_splinter_pretraining_loss_requires_question_positions(self):
+ model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")
+
+ # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]"
+ # Output should be the spans "Brad" and "the United Kingdom"
+ input_ids = torch.tensor(
+ [[101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102]]
+ )
+ start_positions = torch.tensor([[7, 10]], dtype=torch.long)
+ end_positions = torch.tensor([7, 12], dtype=torch.long)
+ with self.assertRaises(TypeError):
+ model(
+ input_ids,
+ start_positions=start_positions,
+ end_positions=end_positions,
+ )
+
+ @slow
+ def test_splinter_pretraining_loss(self):
+ model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")
+
+ # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]"
+ # Output should be the spans "Brad" and "the United Kingdom"
+ input_ids = torch.tensor(
+ [
+ [101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102],
+ [101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102],
+ ]
+ )
+ start_positions = torch.tensor([[7, 10], [7, 10]], dtype=torch.long)
+ end_positions = torch.tensor([[7, 12], [7, 12]], dtype=torch.long)
+ question_positions = torch.tensor([[1, 5], [1, 5]], dtype=torch.long)
+ output = model(
+ input_ids,
+ start_positions=start_positions,
+ end_positions=end_positions,
+ question_positions=question_positions,
+ )
+ self.assertAlmostEqual(output.loss.item(), 0.0024, 4)
+
+ @slow
+ def test_splinter_pretraining_loss_with_padding(self):
+ model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")
+
+ # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]"
+ # Output should be the spans "Brad" and "the United Kingdom"
+ input_ids = torch.tensor(
+ [
+ [101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102],
+ ]
+ )
+ start_positions = torch.tensor([[7, 10]], dtype=torch.long)
+ end_positions = torch.tensor([7, 12], dtype=torch.long)
+ question_positions = torch.tensor([[1, 5]], dtype=torch.long)
+ start_positions_with_padding = torch.tensor([[7, 10, 0]], dtype=torch.long)
+ end_positions_with_padding = torch.tensor([7, 12, 0], dtype=torch.long)
+ question_positions_with_padding = torch.tensor([[1, 5, 0]], dtype=torch.long)
+ output = model(
+ input_ids,
+ start_positions=start_positions,
+ end_positions=end_positions,
+ question_positions=question_positions,
+ )
+ output_with_padding = model(
+ input_ids,
+ start_positions=start_positions_with_padding,
+ end_positions=end_positions_with_padding,
+ question_positions=question_positions_with_padding,
+ )
+
+ self.assertAlmostEqual(output.loss.item(), output_with_padding.loss.item(), 4)
+
+ # Note that the original code uses 0 to denote padded question tokens
+ # and their start and end positions. As the pad_token_id of the model's
+ # config is used as the loss's ignore_index in SplinterForPreTraining,
+ # we add this test to ensure that anybody changing the default value
+ # of the config will be aware of the implication.
+ self.assertEqual(model.config.pad_token_id, 0)
+
+ @slow
+ def test_splinter_pretraining_prepare_question_positions(self):
+ model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")
+
+ input_ids = torch.tensor(
+ [
+ [101, 104, 1, 2, 104, 3, 4, 102],
+ [101, 1, 104, 2, 104, 3, 104, 102],
+ [101, 1, 2, 104, 104, 3, 4, 102],
+ [101, 1, 2, 3, 4, 5, 104, 102],
+ ]
+ )
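+ # token id 104 marks the [QUESTION] positions (see the tests above); the tensor below lists
+ # those indices per row, right-padded with 0, which is what `_prepare_question_positions()`
+ # should infer by itself when no `question_positions` are passed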
+ question_positions = torch.tensor([[1, 4, 0], [2, 4, 6], [3, 4, 0], [6, 0, 0]], dtype=torch.long)
+ output_without_positions = model(input_ids)
+ output_with_positions = model(input_ids, question_positions=question_positions)
+ self.assertTrue((output_without_positions.start_logits == output_with_positions.start_logits).all())
+ self.assertTrue((output_without_positions.end_logits == output_with_positions.end_logits).all())
diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py
index ef7a64e998d7..5e07efa2a3dc 100644
--- a/tests/models/swin/test_modeling_swin.py
+++ b/tests/models/swin/test_modeling_swin.py
@@ -14,16 +14,19 @@
# limitations under the License.
""" Testing suite for the PyTorch Swin model. """
-import copy
+import collections
import inspect
+import os
+import pickle
+import tempfile
import unittest
from transformers import SwinConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
-from transformers.utils import cached_property, is_torch_available, is_vision_available
+from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available
from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available():
@@ -31,20 +34,15 @@
from torch import nn
from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
- from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
+ from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
from transformers import AutoFeatureExtractor
-
-def _config_zero_init(config):
- configs_no_init = copy.deepcopy(config)
- for key in configs_no_init.__dict__.keys():
- if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
- setattr(configs_no_init, key, 1e-10)
- return configs_no_init
+if is_torch_fx_available():
+ from transformers.utils.fx import symbolic_trace
class SwinModelTester:
@@ -144,6 +142,25 @@ def create_and_check_model(self, config, pixel_values, labels):
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
+ def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
+ model = SwinForMaskedImageModeling(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
+ )
+
+ # test greyscale images
+ config.num_channels = 1
+ model = SwinForMaskedImageModeling(config)
+ model.to(torch_device)
+ model.eval()
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
+
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = SwinForImageClassification(config)
@@ -152,6 +169,16 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+ # test greyscale images
+ config.num_channels = 1
+ model = SwinForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -175,6 +202,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
+ fx_compatible = True
test_pruning = False
test_resize_embeddings = False
@@ -200,6 +228,14 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
+ def test_for_masked_image_modeling(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
def test_inputs_embeds(self):
# Swin does not use inputs_embeds
pass
@@ -285,60 +321,92 @@ def test_attention_outputs(self):
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
)
- def test_hidden_states_output(self):
- def check_hidden_states_output(inputs_dict, config, model_class):
- model = model_class(config)
- model.to(torch_device)
- model.eval()
+ def check_hidden_states_output(self, inputs_dict, config, model_class, image_size):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- hidden_states = outputs.hidden_states
+ hidden_states = outputs.hidden_states
- expected_num_layers = getattr(
- self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
- )
- self.assertEqual(len(hidden_states), expected_num_layers)
+ expected_num_layers = getattr(
+ self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
+ )
+ self.assertEqual(len(hidden_states), expected_num_layers)
- # Swin has a different seq_length
- image_size = to_2tuple(self.model_tester.image_size)
- patch_size = to_2tuple(self.model_tester.patch_size)
+ # Swin has a different seq_length
+ patch_size = (
+ config.patch_size
+ if isinstance(config.patch_size, collections.abc.Iterable)
+ else (config.patch_size, config.patch_size)
+ )
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- self.assertListEqual(
- list(hidden_states[0].shape[-2:]),
- [num_patches, self.model_tester.embed_dim],
- )
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [num_patches, self.model_tester.embed_dim],
+ )
- reshaped_hidden_states = outputs.reshaped_hidden_states
- self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
+ reshaped_hidden_states = outputs.reshaped_hidden_states
+ self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
- batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
- reshaped_hidden_states = (
- reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
- )
- self.assertListEqual(
- list(reshaped_hidden_states.shape[-2:]),
- [num_patches, self.model_tester.embed_dim],
- )
+ batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
+ reshaped_hidden_states = (
+ reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
+ )
+ self.assertListEqual(
+ list(reshaped_hidden_states.shape[-2:]),
+ [num_patches, self.model_tester.embed_dim],
+ )
+ def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ image_size = (
+ self.model_tester.image_size
+ if isinstance(self.model_tester.image_size, collections.abc.Iterable)
+ else (self.model_tester.image_size, self.model_tester.image_size)
+ )
+
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
- check_hidden_states_output(inputs_dict, config, model_class)
+ self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
- check_hidden_states_output(inputs_dict, config, model_class)
+ self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
- def test_for_image_classification(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+ def test_hidden_states_output_with_padding(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.patch_size = 3
+
+ image_size = (
+ self.model_tester.image_size
+ if isinstance(self.model_tester.image_size, collections.abc.Iterable)
+ else (self.model_tester.image_size, self.model_tester.image_size)
+ )
+ patch_size = (
+ config.patch_size
+ if isinstance(config.patch_size, collections.abc.Iterable)
+ else (config.patch_size, config.patch_size)
+ )
+
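+ # patch_size=3 does not evenly divide the test images, so the inputs get padded up to the
+ # next multiple of the patch size (e.g. a 32-pixel side becomes 32 + 3 - (32 % 3) = 33)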
+ padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
+ padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
@slow
def test_model_from_pretrained(self):
@@ -360,6 +428,99 @@ def test_initialization(self):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
+ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+ if not is_torch_fx_available() or not self.fx_compatible:
+ return
+
+ configs_no_init = _config_zero_init(config) # To be sure we have no Nan
+ configs_no_init.return_dict = False
+
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
+
+ try:
+ if model.config.is_encoder_decoder:
+ model.config.use_cache = False # FSMT still requires this hack -> FSMT should probably be refactored similarly to BART afterward
+ labels = inputs.get("labels", None)
+ input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
+ if labels is not None:
+ input_names.append("labels")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+ else:
+ input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
+
+ labels = inputs.get("labels", None)
+ start_positions = inputs.get("start_positions", None)
+ end_positions = inputs.get("end_positions", None)
+ if labels is not None:
+ input_names.append("labels")
+ if start_positions is not None:
+ input_names.append("start_positions")
+ if end_positions is not None:
+ input_names.append("end_positions")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+
+ except RuntimeError as e:
+ self.fail(f"Couldn't trace module: {e}")
+
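+ # recursively collect the tensors from (possibly nested) model outputs, skipping
+ # non-tensor entries, so eager, traced and unpickled outputs can be compared below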
+ def flatten_output(output):
+ flatten = []
+ for x in output:
+ if isinstance(x, (tuple, list)):
+ flatten += flatten_output(x)
+ elif not isinstance(x, torch.Tensor):
+ continue
+ else:
+ flatten.append(x)
+ return flatten
+
+ model_output = flatten_output(model_output)
+ traced_output = flatten_output(traced_output)
+ num_outputs = len(model_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], traced_output[i]),
+ f"traced {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
+ # Test that the model can be serialized and restored properly
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+ try:
+ with open(pkl_file_name, "wb") as f:
+ pickle.dump(traced_model, f)
+ with open(pkl_file_name, "rb") as f:
+ loaded = pickle.load(f)
+ except Exception as e:
+ self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+ loaded_output = loaded(**filtered_inputs)
+ loaded_output = flatten_output(loaded_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], loaded_output[i]),
+ f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
@require_vision
@require_torch
diff --git a/tests/models/swin/test_modeling_tf_swin.py b/tests/models/swin/test_modeling_tf_swin.py
new file mode 100644
index 000000000000..be5861ce48b4
--- /dev/null
+++ b/tests/models/swin/test_modeling_tf_swin.py
@@ -0,0 +1,405 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the TF 2.0 Swin model. """
+
+
+import inspect
+import unittest
+
+import numpy as np
+
+from transformers import SwinConfig
+from transformers.testing_utils import require_tf, require_vision, slow, to_2tuple, tooslow
+from transformers.utils import cached_property, is_tf_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_tf_available():
+ import tensorflow as tf
+
+ from transformers.models.swin.modeling_tf_swin import (
+ TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TFSwinForImageClassification,
+ TFSwinForMaskedImageModeling,
+ TFSwinModel,
+ )
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import AutoFeatureExtractor
+
+
+class TFSwinModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=32,
+ patch_size=2,
+ num_channels=3,
+ embed_dim=16,
+ depths=[1, 2, 1],
+ num_heads=[2, 2, 4],
+ window_size=2,
+ mlp_ratio=2.0,
+ qkv_bias=True,
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ drop_path_rate=0.1,
+ hidden_act="gelu",
+ use_absolute_embeddings=False,
+ patch_norm=True,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ is_training=True,
+ scope=None,
+ use_labels=True,
+ type_sequence_label_size=10,
+ encoder_stride=8,
+ ) -> None:
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.embed_dim = embed_dim
+ self.depths = depths
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.mlp_ratio = mlp_ratio
+ self.qkv_bias = qkv_bias
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.drop_path_rate = drop_path_rate
+ self.hidden_act = hidden_act
+ self.use_absolute_embeddings = use_absolute_embeddings
+ self.patch_norm = patch_norm
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ self.is_training = is_training
+ self.scope = scope
+ self.use_labels = use_labels
+ self.type_sequence_label_size = type_sequence_label_size
+ self.encoder_stride = encoder_stride
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return SwinConfig(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ embed_dim=self.embed_dim,
+ depths=self.depths,
+ num_heads=self.num_heads,
+ window_size=self.window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=self.qkv_bias,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ drop_path_rate=self.drop_path_rate,
+ hidden_act=self.hidden_act,
+ use_absolute_embeddings=self.use_absolute_embeddings,
+ patch_norm=self.patch_norm,
+ layer_norm_eps=self.layer_norm_eps,
+ initializer_range=self.initializer_range,
+ encoder_stride=self.encoder_stride,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = TFSwinModel(config=config)
+ result = model(pixel_values)
+
+ expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
+ expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
+
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
+
+ def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
+ model = TFSwinForMaskedImageModeling(config=config)
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
+ )
+
+ # test greyscale images
+ config.num_channels = 1
+ model = TFSwinForMaskedImageModeling(config)
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.type_sequence_label_size
+ model = TFSwinForImageClassification(config)
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
+ # test greyscale images
+ config.num_channels = 1
+ model = TFSwinForImageClassification(config)
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_tf
+class TFSwinModelTest(TFModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (
+ (
+ TFSwinModel,
+ TFSwinForImageClassification,
+ TFSwinForMaskedImageModeling,
+ )
+ if is_tf_available()
+ else ()
+ )
+
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ test_onnx = False
+
+ def setUp(self):
+ self.model_tester = TFSwinModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=SwinConfig, embed_dim=37)
+
+ def test_config(self):
+ self.create_and_test_config_common_properties()
+ self.config_tester.create_and_test_config_to_json_string()
+ self.config_tester.create_and_test_config_to_json_file()
+ self.config_tester.create_and_test_config_from_and_save_pretrained()
+ self.config_tester.create_and_test_config_with_num_labels()
+ self.config_tester.check_config_can_be_init_without_params()
+ self.config_tester.check_config_arguments_init()
+
+ def create_and_test_config_common_properties(self):
+ return
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_masked_image_modeling(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ @unittest.skip(reason="Swin does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @tooslow
+ def test_saved_model_creation(self):
+ pass
+
+ def test_model_common_attributes(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, tf.keras.layers.Dense))
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.call)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ expected_num_attentions = len(self.model_tester.depths)
+ self.assertEqual(len(attentions), expected_num_attentions)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ window_size_squared = config.window_size**2
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), expected_num_attentions)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ else:
+ # also another +1 for reshaped_hidden_states
+ added_hidden_states = 2
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.attentions
+
+ self.assertEqual(len(self_attentions), expected_num_attentions)
+
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
+ )
+
+ def check_hidden_states_output(self, inputs_dict, config, model_class, image_size):
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ hidden_states = outputs.hidden_states
+
+ expected_num_layers = getattr(
+ self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
+ )
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ # Swin has a different seq_length
+ patch_size = to_2tuple(config.patch_size)
+
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [num_patches, self.model_tester.embed_dim],
+ )
+
+ reshaped_hidden_states = outputs.reshaped_hidden_states
+ self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
+
+ batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
+
+ reshaped_hidden_states = tf.reshape(reshaped_hidden_states[0], (batch_size, num_channels, height * width))
+ reshaped_hidden_states = tf.transpose(reshaped_hidden_states, (0, 2, 1))
+
+ self.assertListEqual(
+ list(reshaped_hidden_states.shape[-2:]),
+ [num_patches, self.model_tester.embed_dim],
+ )
+
+ def test_hidden_states_output(self):
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ image_size = to_2tuple(self.model_tester.image_size)
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
+
+ def test_inputs_requiring_padding(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.patch_size = 3
+
+ image_size = to_2tuple(self.model_tester.image_size)
+ patch_size = to_2tuple(config.patch_size)
+
+ padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
+ padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = TFSwinModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+@require_vision
+@require_tf
+class TFSwinModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return (
+ AutoFeatureExtractor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
+ if is_vision_available()
+ else None
+ )
+
+ @slow
+ def test_inference_image_classification_head(self):
+ model = TFSwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
+ feature_extractor = self.default_feature_extractor
+
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ inputs = feature_extractor(images=image, return_tensors="tf")
+
+ # forward pass
+ outputs = model(inputs)
+
+ # verify the logits
+ expected_shape = tf.TensorShape((1, 1000))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+ expected_slice = tf.constant([-0.0948, -0.6454, -0.0921])
+ self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
diff --git a/tests/models/swinv2/__init__.py b/tests/models/swinv2/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/swinv2/test_modeling_swinv2.py b/tests/models/swinv2/test_modeling_swinv2.py
new file mode 100644
index 000000000000..13a39b139c81
--- /dev/null
+++ b/tests/models/swinv2/test_modeling_swinv2.py
@@ -0,0 +1,430 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch Swinv2 model. """
+import collections
+import inspect
+import unittest
+
+from transformers import Swinv2Config
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.utils import cached_property, is_torch_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import Swinv2ForImageClassification, Swinv2ForMaskedImageModeling, Swinv2Model
+ from transformers.models.swinv2.modeling_swinv2 import SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import AutoFeatureExtractor
+
+
+class Swinv2ModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=32,
+ patch_size=2,
+ num_channels=3,
+ embed_dim=16,
+ depths=[1, 2, 1],
+ num_heads=[2, 2, 4],
+ window_size=2,
+ mlp_ratio=2.0,
+ qkv_bias=True,
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ drop_path_rate=0.1,
+ hidden_act="gelu",
+ use_absolute_embeddings=False,
+ patch_norm=True,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ is_training=True,
+ scope=None,
+ use_labels=True,
+ type_sequence_label_size=10,
+ encoder_stride=8,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.embed_dim = embed_dim
+ self.depths = depths
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.mlp_ratio = mlp_ratio
+ self.qkv_bias = qkv_bias
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.drop_path_rate = drop_path_rate
+ self.hidden_act = hidden_act
+ self.use_absolute_embeddings = use_absolute_embeddings
+ self.patch_norm = patch_norm
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ self.is_training = is_training
+ self.scope = scope
+ self.use_labels = use_labels
+ self.type_sequence_label_size = type_sequence_label_size
+ self.encoder_stride = encoder_stride
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return Swinv2Config(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ embed_dim=self.embed_dim,
+ depths=self.depths,
+ num_heads=self.num_heads,
+ window_size=self.window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=self.qkv_bias,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ drop_path_rate=self.drop_path_rate,
+ hidden_act=self.hidden_act,
+ use_absolute_embeddings=self.use_absolute_embeddings,
+ patch_norm=self.patch_norm,
+ layer_norm_eps=self.layer_norm_eps,
+ initializer_range=self.initializer_range,
+ encoder_stride=self.encoder_stride,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = Swinv2Model(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+
+ expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
+ expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
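+ # with the tester defaults: (32 // 2) ** 2 = 256 patches, reduced by a factor of 4 at each
+ # of the 2 patch-merging stages -> seq_len 16; embed_dim 16 * 2 ** 2 = 64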
+
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
+
+ def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
+ model = Swinv2ForMaskedImageModeling(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
+ )
+
+ # test greyscale images
+ config.num_channels = 1
+ model = Swinv2ForMaskedImageModeling(config)
+ model.to(torch_device)
+ model.eval()
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.type_sequence_label_size
+ model = Swinv2ForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class Swinv2ModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (
+ (Swinv2Model, Swinv2ForImageClassification, Swinv2ForMaskedImageModeling) if is_torch_available() else ()
+ )
+
+ fx_compatible = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = Swinv2ModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=Swinv2Config, embed_dim=37)
+
+ def test_config(self):
+ self.config_tester.create_and_test_config_to_json_string()
+ self.config_tester.create_and_test_config_to_json_file()
+ self.config_tester.create_and_test_config_from_and_save_pretrained()
+ self.config_tester.create_and_test_config_with_num_labels()
+ self.config_tester.check_config_can_be_init_without_params()
+ self.config_tester.check_config_arguments_init()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ @unittest.skip(reason="Swinv2 does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ def test_model_common_attributes(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ expected_num_attentions = len(self.model_tester.depths)
+ self.assertEqual(len(attentions), expected_num_attentions)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ window_size_squared = config.window_size**2
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), expected_num_attentions)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ else:
+ # also another +1 for reshaped_hidden_states
+ added_hidden_states = 2
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.attentions
+
+ self.assertEqual(len(self_attentions), expected_num_attentions)
+
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
+ )
+
+ def check_hidden_states_output(self, inputs_dict, config, model_class, image_size):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.hidden_states
+
+ expected_num_layers = getattr(
+ self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
+ )
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ # Swinv2 has a different seq_length
+ patch_size = (
+ config.patch_size
+ if isinstance(config.patch_size, collections.abc.Iterable)
+ else (config.patch_size, config.patch_size)
+ )
+
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [num_patches, self.model_tester.embed_dim],
+ )
+
+ reshaped_hidden_states = outputs.reshaped_hidden_states
+ self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
+
+ batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
+ reshaped_hidden_states = (
+ reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
+ )
+ self.assertListEqual(
+ list(reshaped_hidden_states.shape[-2:]),
+ [num_patches, self.model_tester.embed_dim],
+ )
+
+ def test_hidden_states_output(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ image_size = (
+ self.model_tester.image_size
+ if isinstance(self.model_tester.image_size, collections.abc.Iterable)
+ else (self.model_tester.image_size, self.model_tester.image_size)
+ )
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
+
+ def test_hidden_states_output_with_padding(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.patch_size = 3
+
+ image_size = (
+ self.model_tester.image_size
+ if isinstance(self.model_tester.image_size, collections.abc.Iterable)
+ else (self.model_tester.image_size, self.model_tester.image_size)
+ )
+ patch_size = (
+ config.patch_size
+ if isinstance(config.patch_size, collections.abc.Iterable)
+ else (config.patch_size, config.patch_size)
+ )
+
+ padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
+ padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
+
+ def test_for_masked_image_modeling(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = Swinv2Model.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
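+ # with the near-zero initializer ranges set by `_config_zero_init`, every checked parameter
+ # is expected to be initialized to (approximately) all zeros or all ones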
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ if "embeddings" not in name and "logit_scale" not in name and param.requires_grad:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+
+@require_vision
+@require_torch
+class Swinv2ModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return (
+ AutoFeatureExtractor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
+ if is_vision_available()
+ else None
+ )
+
+ @slow
+ def test_inference_image_classification_head(self):
+ model = Swinv2ForImageClassification.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256").to(
+ torch_device
+ )
+ feature_extractor = self.default_feature_extractor
+
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ expected_shape = torch.Size((1, 1000))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+ expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device)
+ self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
diff --git a/tests/models/t5/test_modeling_flax_t5.py b/tests/models/t5/test_modeling_flax_t5.py
index 7971bb4116df..3186567709d2 100644
--- a/tests/models/t5/test_modeling_flax_t5.py
+++ b/tests/models/t5/test_modeling_flax_t5.py
@@ -48,7 +48,12 @@
from flax.traverse_util import flatten_dict
from transformers import FLAX_MODEL_MAPPING, ByT5Tokenizer, T5Config, T5Tokenizer
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
- from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, shift_tokens_right
+ from transformers.models.t5.modeling_flax_t5 import (
+ FlaxT5EncoderModel,
+ FlaxT5ForConditionalGeneration,
+ FlaxT5Model,
+ shift_tokens_right,
+ )
class FlaxT5ModelTester:
@@ -461,6 +466,298 @@ def test_save_load_bf16_to_base_pt(self):
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+class FlaxT5EncoderOnlyModelTester:
+ def __init__(
+ self,
+ parent,
+ vocab_size=99,
+ batch_size=13,
+ encoder_seq_length=7,
+ # For common tests
+ is_training=True,
+ use_attention_mask=True,
+ use_labels=True,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ d_ff=37,
+ relative_attention_num_buckets=8,
+ dropout_rate=0.1,
+ initializer_factor=0.002,
+ eos_token_id=1,
+ pad_token_id=0,
+ decoder_start_token_id=0,
+ scope=None,
+ ):
+
+ self.parent = parent
+ self.batch_size = batch_size
+ self.encoder_seq_length = encoder_seq_length
+ # For common tests
+ self.seq_length = self.encoder_seq_length
+ self.is_training = is_training
+ self.use_attention_mask = use_attention_mask
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.d_ff = d_ff
+ self.relative_attention_num_buckets = relative_attention_num_buckets
+ self.dropout_rate = dropout_rate
+ self.initializer_factor = initializer_factor
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.decoder_start_token_id = decoder_start_token_id
+ self.scope = None
+ self.decoder_layers = 0
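+ # zero decoder layers (together with `is_encoder_decoder=False` in `get_config`) yields an
+ # encoder-only T5 configuration for testing FlaxT5EncoderModel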
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
+
+ attention_mask = None
+ if self.use_attention_mask:
+ attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
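+ # ids_tensor with vocab_size=2 yields a random 0/1 tensor, used as the attention mask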
+
+ config = T5Config(
+ vocab_size=self.vocab_size,
+ d_model=self.hidden_size,
+ d_ff=self.d_ff,
+ d_kv=self.hidden_size // self.num_attention_heads,
+ num_layers=self.num_hidden_layers,
+ num_decoder_layers=self.decoder_layers,
+ num_heads=self.num_attention_heads,
+ relative_attention_num_buckets=self.relative_attention_num_buckets,
+ dropout_rate=self.dropout_rate,
+ initializer_factor=self.initializer_factor,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.pad_token_id,
+ pad_token_id=self.pad_token_id,
+ decoder_start_token_id=self.decoder_start_token_id,
+ is_encoder_decoder=False,
+ )
+
+ return (
+ config,
+ input_ids,
+ attention_mask,
+ )
+
+ def create_and_check_model(
+ self,
+ config,
+ input_ids,
+ attention_mask,
+ ):
+ model = FlaxT5EncoderModel(config=config)
+ result = model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )
+ result = model(input_ids=input_ids)
+ encoder_output = result.last_hidden_state
+
+ self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ attention_mask,
+ ) = config_and_inputs
+
+ inputs_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_flax
+class FlaxT5EncoderOnlyModelTest(FlaxModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (FlaxT5EncoderModel,) if is_flax_available() else ()
+ is_encoder_decoder = False
+
+ def setUp(self):
+ self.model_tester = FlaxT5EncoderOnlyModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_v1_1(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ # check that gated gelu feed forward and different word embeddings work
+ config = config_and_inputs[0]
+ config.tie_word_embeddings = False
+ config.feed_forward_proj = "gated-gelu"
+ self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
+
+ def test_encode(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ with self.subTest(model_class.__name__):
+ prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ model = model_class(config)
+
+ @jax.jit
+ def encode_jitted(input_ids, attention_mask=None, **kwargs):
+ return model(input_ids=input_ids, attention_mask=attention_mask)
+
+ with self.subTest("JIT Enabled"):
+ jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
+
+ with self.subTest("JIT Disabled"):
+ with jax.disable_jit():
+ outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
+
+ self.assertEqual(len(outputs), len(jitted_outputs))
+ for jitted_output, output in zip(jitted_outputs, outputs):
+ self.assertEqual(jitted_output.shape, output.shape)
+
+ # overwrite since special base model prefix is used
+ def test_save_load_from_base(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
+ model = base_class(config)
+ base_params = flatten_dict(unfreeze(model.params))
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ head_model = model_class.from_pretrained(tmpdirname)
+
+ base_param_from_head = flatten_dict(unfreeze(head_model.params))
+
+ for key in base_param_from_head.keys():
+ max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+ # overwrite since special base model prefix is used
+ def test_save_load_to_base(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
+ model = model_class(config)
+ base_params_from_head = flatten_dict(unfreeze(model.params))
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ base_model = base_class.from_pretrained(tmpdirname)
+
+ base_params = flatten_dict(unfreeze(base_model.params))
+
+ for key in base_params_from_head.keys():
+ max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+ # overwrite since special base model prefix is used
+ @is_pt_flax_cross_test
+ def test_save_load_from_base_pt(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
+ model = base_class(config)
+ base_params = flatten_dict(unfreeze(model.params))
+
+ # convert Flax model to PyTorch model
+ pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
+ pt_model = pt_model_class(config).eval()
+ pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ # save pt model
+ pt_model.save_pretrained(tmpdirname)
+ head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
+
+ base_param_from_head = flatten_dict(unfreeze(head_model.params))
+
+ for key in base_param_from_head.keys():
+ max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+ # overwrite since special base model prefix is used
+ @is_pt_flax_cross_test
+ def test_save_load_to_base_pt(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
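+ # randomly initialized Flax head model, ported to PyTorch below and then reloaded as a Flax base model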
+ model = model_class(config)
+ base_params_from_head = flatten_dict(unfreeze(model.params))
+
+ # convert Flax model to PyTorch model
+ pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
+ pt_model = pt_model_class(config).eval()
+ pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pt_model.save_pretrained(tmpdirname)
+ base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
+
+ base_params = flatten_dict(unfreeze(base_model.params))
+
+ for key in base_params_from_head.keys():
+ max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+ # overwrite since special base model prefix is used
+ @is_pt_flax_cross_test
+ def test_save_load_bf16_to_base_pt(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
+ model = model_class(config)
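+ # cast the Flax parameters to bfloat16 so the round trip through PyTorch is checked with half-precision weights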
+ model.params = model.to_bf16(model.params)
+ base_params_from_head = flatten_dict(unfreeze(model.params))
+
+ # convert Flax model to PyTorch model
+ pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
+ pt_model = pt_model_class(config).eval()
+ pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pt_model.save_pretrained(tmpdirname)
+ base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
+
+ base_params = flatten_dict(unfreeze(base_model.params))
+
+ for key in base_params_from_head.keys():
+ max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+
@require_sentencepiece
@require_tokenizers
@require_flax
@@ -573,16 +870,208 @@ def test_summarization(self):
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
tok = T5Tokenizer.from_pretrained("t5-base")
- FRANCE_ARTICLE = 'Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. 
Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. 
CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noqa
- SHORTER_ARTICLE = '(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. 
The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
- IRAN_ARTICLE = "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. 
This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions."
- ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
+ FRANCE_ARTICLE = ( # @noqa
+ "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
+ " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane."
+ ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."'
+ ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s'
+ " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French"
+ " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a"
+ " phone at the wreckage site. The two publications described the supposed video, but did not post it on"
+ " their websites. The publications said that they watched the video, which was found by a source close to"
+ " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported."
+ ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the'
+ " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the"
+ ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,'
+ " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said"
+ " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman"
+ " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the"
+ ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,'
+ ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be'
+ " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by"
+ " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so"
+ " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could"
+ ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin'
+ ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match'
+ ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something'
+ " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the"
+ ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline'
+ " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the"
+ " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the"
+ ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of'
+ ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school'
+ " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in"
+ " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent"
+ " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and"
+ " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%"
+ ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was'
+ " sharing the information and documents -- including training and medical records -- with public"
+ " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the"
+ " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the"
+ " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash"
+ " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late"
+ " Tuesday that no visible human remains were left at the site but recovery teams would keep searching."
+ " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all"
+ " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said."
+ " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew"
+ " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with"
+ " the flight school during his training were among several developments as investigators continued to"
+ " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa"
+ " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his"
+ ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in'
+ " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at"
+ " some point before his aviation career and underwent psychotherapy before he got his pilot's license."
+ " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the"
+ " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to"
+ " lose his pilot's license, a European government official briefed on the investigation told CNN on"
+ ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being'
+ " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that"
+ " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would"
+ " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had"
+ " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded"
+ " he had psychological issues, the European government official said. But no matter what details emerge"
+ " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact'
+ " that maybe they weren't going to keep doing their job and they're upset about that and so they're"
+ ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to'
+ " also take that rage and turn it outward on 149 other people who had nothing to do with the person's"
+ ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight'
+ " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura"
+ " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine"
+ " Amiel and Anna-Maja Rappard contributed to this report."
+ )
+ SHORTER_ARTICLE = (
+ "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
+ " and Faith Karimi contributed to this report."
+ )
+ IRAN_ARTICLE = (
+ "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran"
+ " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively"
+ " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger."
+ " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli"
+ " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a"
+ " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since"
+ " the announcement of the new framework will likely result in more heat than light. It will not be helped"
+ " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ."
+ " The most misleading assertion, despite universal rejection by experts, is that the negotiations'"
+ " objective at the outset was the total elimination of any nuclear program in Iran. That is the position"
+ " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it"
+ " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has"
+ " always been to structure an agreement or series of agreements so that Iran could not covertly develop a"
+ " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded"
+ " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by"
+ " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another"
+ " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite"
+ " sharp accusations by some in the United States and its allies, Iran denies having such a program, and"
+ " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's"
+ " continued cooperation with International Atomic Energy Agency inspections is further evidence on this"
+ " point, and we'll know even more about Iran's program in the coming months and years because of the deal."
+ " In fact, the inspections provisions that are part of this agreement are designed to protect against any"
+ " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that"
+ " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter"
+ " warning that a deal might be killed by Congress or a future president). This of course is not the case."
+ " The talks were between Iran and the five permanent members of the U.N. Security Council (United States,"
+ " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has"
+ " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement"
+ " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran"
+ " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement"
+ " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the"
+ " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased"
+ " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes"
+ " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear"
+ " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going"
+ " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such"
+ " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the"
+ ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not'
+ " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New"
+ " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement"
+ " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement"
+ " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove"
+ " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally"
+ " some insist that any agreement must address Iranian missile programs, human rights violations or support"
+ " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are"
+ " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in"
+ " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it"
+ " affects the security of our negotiating partners and allies, including Israel. Those judgments should be"
+ " fact-based, not based on questionable assertions or dubious assumptions."
+ )
+ ARTICLE_SUBWAY = (
+ "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
+ " year later, she got married again in Westchester County, but to a different man and without divorcing"
+ " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
+ ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
+ " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
+ ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
+ ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
+ " license application, according to court documents. Prosecutors said the marriages were part of an"
+ " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
+ " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
+ " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
+ " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
+ " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
+ " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
+ " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
+ " said the immigration scam involved some of her husbands, who filed for permanent residence status"
+ " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
+ " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
+ " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
+ ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
+ " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
+ " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
+ " up to four years in prison. Her next court appearance is scheduled for May 18."
+ )
expected_summaries = [
- 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says . all 150 on board were killed when germanwings flight 9525 crashed .',
- "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .",
- "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . he says the new framework would reduce Iran's low-enriched uranium stockpile and cut centrifuges . miller: if it had been, there would have been no Iranian team at the table .",
- 'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
+ 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
+ " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
+ " magazine says . all 150 on board were killed when germanwings flight 9525 crashed .",
+ "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
+ " preliminary examination into the situation in the occupied Palestinian territory . as members of the"
+ " court, Palestinians may be subject to counter-charges as well .",
+ "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:"
+ " the debate that has already begun since the announcement of the new framework will likely result in more"
+ " heat than light . he says the new framework would reduce Iran's low-enriched uranium stockpile and cut"
+ " centrifuges . miller: if it had been, there would have been no Iranian team at the table .",
+ "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two"
+ ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10'
+ " times, with nine of her marriages occurring between 1999 and 2002 .",
]
dct = tok(
diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py
index d57300418473..3ed5521a62d0 100644
--- a/tests/models/t5/test_modeling_t5.py
+++ b/tests/models/t5/test_modeling_t5.py
@@ -509,12 +509,14 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
- fx_compatible = True
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
+ fx_compatible = True
test_pruning = False
test_resize_embeddings = True
test_model_parallel = True
is_encoder_decoder = True
+ # The small T5 model needs higher percentages for CPU/MP tests
+ model_split_percents = [0.8, 0.9]
def setUp(self):
self.model_tester = T5ModelTester(self)
@@ -539,6 +541,12 @@ def test_model_v1_1(self):
config.feed_forward_proj = "gated-gelu"
self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
+ def test_config_and_model_silu_gated(self):
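+ # same check as test_model_v1_1 above, but with the gated-silu feed-forward projection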
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ config = config_and_inputs[0]
+ config.feed_forward_proj = "gated-silu"
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
def test_with_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_lm_head(*config_and_inputs)
@@ -654,6 +662,10 @@ def test_generate_with_head_masking(self):
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
+ @unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
+ def test_disk_offload(self):
+ pass
+
class T5EncoderOnlyModelTester:
def __init__(
@@ -909,16 +921,208 @@ def test_summarization(self):
model = self.model
tok = self.tokenizer
- FRANCE_ARTICLE = 'Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. 
Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. 
CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noqa
- SHORTER_ARTICLE = '(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. 
The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
- IRAN_ARTICLE = "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions."
- ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
+ FRANCE_ARTICLE = ( # @noqa
+ "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
+ " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane."
+ ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."'
+ ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s'
+ " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French"
+ " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a"
+ " phone at the wreckage site. The two publications described the supposed video, but did not post it on"
+ " their websites. The publications said that they watched the video, which was found by a source close to"
+ " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported."
+ ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the'
+ " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the"
+ ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,'
+ " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said"
+ " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman"
+ " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the"
+ ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,'
+ ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be'
+ " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by"
+ " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so"
+ " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could"
+ ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin'
+ ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match'
+ ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something'
+ " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the"
+ ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline'
+ " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the"
+ " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the"
+ ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of'
+ ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school'
+ " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in"
+ " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent"
+ " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and"
+ " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%"
+ ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was'
+ " sharing the information and documents -- including training and medical records -- with public"
+ " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the"
+ " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the"
+ " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash"
+ " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late"
+ " Tuesday that no visible human remains were left at the site but recovery teams would keep searching."
+ " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all"
+ " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said."
+ " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew"
+ " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with"
+ " the flight school during his training were among several developments as investigators continued to"
+ " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa"
+ " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his"
+ ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in'
+ " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at"
+ " some point before his aviation career and underwent psychotherapy before he got his pilot's license."
+ " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the"
+ " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to"
+ " lose his pilot's license, a European government official briefed on the investigation told CNN on"
+ ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being'
+ " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that"
+ " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would"
+ " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had"
+ " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded"
+ " he had psychological issues, the European government official said. But no matter what details emerge"
+ " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact'
+ " that maybe they weren't going to keep doing their job and they're upset about that and so they're"
+ ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to'
+ " also take that rage and turn it outward on 149 other people who had nothing to do with the person's"
+ ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight'
+ " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura"
+ " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine"
+ " Amiel and Anna-Maja Rappard contributed to this report."
+ )
+ SHORTER_ARTICLE = (
+ "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
+ " and Faith Karimi contributed to this report."
+ )
+ IRAN_ARTICLE = (
+ "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran"
+ " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively"
+ " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger."
+ " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli"
+ " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a"
+ " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since"
+ " the announcement of the new framework will likely result in more heat than light. It will not be helped"
+ " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ."
+ " The most misleading assertion, despite universal rejection by experts, is that the negotiations'"
+ " objective at the outset was the total elimination of any nuclear program in Iran. That is the position"
+ " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it"
+ " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has"
+ " always been to structure an agreement or series of agreements so that Iran could not covertly develop a"
+ " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded"
+ " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by"
+ " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another"
+ " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite"
+ " sharp accusations by some in the United States and its allies, Iran denies having such a program, and"
+ " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's"
+ " continued cooperation with International Atomic Energy Agency inspections is further evidence on this"
+ " point, and we'll know even more about Iran's program in the coming months and years because of the deal."
+ " In fact, the inspections provisions that are part of this agreement are designed to protect against any"
+ " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that"
+ " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter"
+ " warning that a deal might be killed by Congress or a future president). This of course is not the case."
+ " The talks were between Iran and the five permanent members of the U.N. Security Council (United States,"
+ " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has"
+ " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement"
+ " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran"
+ " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement"
+ " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the"
+ " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased"
+ " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes"
+ " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear"
+ " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going"
+ " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such"
+ " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the"
+ ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not'
+ " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New"
+ " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement"
+ " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement"
+ " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove"
+ " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally"
+ " some insist that any agreement must address Iranian missile programs, human rights violations or support"
+ " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are"
+ " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in"
+ " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it"
+ " affects the security of our negotiating partners and allies, including Israel. Those judgments should be"
+ " fact-based, not based on questionable assertions or dubious assumptions."
+ )
+ ARTICLE_SUBWAY = (
+ "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
+ " year later, she got married again in Westchester County, but to a different man and without divorcing"
+ " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
+ ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
+ " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
+ ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
+ ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
+ " license application, according to court documents. Prosecutors said the marriages were part of an"
+ " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
+ " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
+ " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
+ " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
+ " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
+ " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
+ " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
+ " said the immigration scam involved some of her husbands, who filed for permanent residence status"
+ " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
+ " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
+ " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
+ ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
+ " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
+ " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
+ " up to four years in prison. Her next court appearance is scheduled for May 18."
+ )
expected_summaries = [
- 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says .',
- "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .",
- "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and implement a rigorous inspection regime .",
- 'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
+ 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
+ " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
+ " magazine says .",
+ "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
+ " preliminary examination into the situation in the occupied Palestinian territory . as members of the"
+ " court, Palestinians may be subject to counter-charges as well .",
+ "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:"
+ " the debate that has already begun since the announcement of the new framework will likely result in more"
+ " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and"
+ " implement a rigorous inspection regime .",
+ "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two"
+ ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10'
+ " times, with nine of her marriages occurring between 1999 and 2002 .",
]
use_task_specific_params(model, "summarization")
@@ -971,7 +1175,10 @@ def test_translation_en_to_fr(self):
tok = self.tokenizer
use_task_specific_params(model, "translation_en_to_fr")
- en_text = ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of countless generations of stars: the oldest stars are seen as blue dots. '
+ en_text = (
+ ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of'
+ " countless generations of stars: the oldest stars are seen as blue dots. "
+ )
input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt")
input_ids = input_ids.to(torch_device)
diff --git a/tests/models/t5/test_modeling_tf_t5.py b/tests/models/t5/test_modeling_tf_t5.py
index 1450a8c7710c..525124297345 100644
--- a/tests/models/t5/test_modeling_tf_t5.py
+++ b/tests/models/t5/test_modeling_tf_t5.py
@@ -16,7 +16,7 @@
import unittest
from transformers import T5Config, is_tf_available
-from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
+from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -227,23 +227,6 @@ def create_and_check_t5_decoder_model_past_large_inputs(
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
- def create_and_check_t5_xla_generate_fast(self, config, input_ids, *args):
- config.eos_token_id = None
- config.max_length = 10
- config.do_sample = False
- config.num_beams = 1
- model = TFT5ForConditionalGeneration(config=config)
-
- # make sure there are no pad tokens in prompt
- input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id + 5)
-
- generated = model.generate(input_ids)
-
- generate_xla = tf.function(model.generate, jit_compile=True)
- generated_xla = generate_xla(input_ids)
-
- self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())
-
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, token_labels) = config_and_inputs
@@ -295,11 +278,14 @@ def test_t5_decoder_model_past_with_attn_mask(self):
def test_t5_decoder_model_past_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs)
- def test_t5_model_xla_generate_fast(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_t5_xla_generate_fast(*config_and_inputs)
+ # `create_and_check_t5_decoder_model_past_large_inputs` has special inputs:
+ # (config, input_ids, decoder_input_ids, attention_mask)
+ # and we have to prepare it correctly here.
+ config, input_ids, input_mask, token_labels = config_and_inputs
+ config_and_inputs = (config, input_ids, None, input_mask)
+
+ self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs)
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -319,8 +305,8 @@ def test_model_common_attributes(self):
name = model.get_bias()
assert name is None
+ @tooslow
def test_saved_model_creation(self):
- # This test is too long (>30sec) and makes fail the CI
pass
@slow
@@ -587,6 +573,39 @@ def test_sample_generate(self):
self.assertListEqual(expected_output_string, output_strings)
+ @slow
+ def test_beam_search_xla_generate_simple(self):
+ model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
+ tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+ # tests XLA with task specific arguments
+ task_specific_config = getattr(model.config, "task_specific_params", {})
+ translation_config = task_specific_config.get("translation_en_to_fr", {})
+ model.config.update(translation_config)
+
+ # two examples with different lengths to confirm that attention masks are operational in XLA
+ sentences = [
+ model.config.prefix + "Today is a beautiful day.",
+ model.config.prefix + "I have four cats, three dogs, two birds, and a horse.",
+ ]
+ input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+
+ xla_generate = tf.function(model.generate, jit_compile=True)
+
+ output_ids = model.generate(input_ids, num_beams=2)
+ output_ids_xla = xla_generate(input_ids, num_beams=2)
+
+ output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+ output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
+
+ expected_output_string = [
+ "Aujourd'hui est une belle journƩe.",
+ "J'ai quatre chats, trois chiens, deux oiseaux et un cheval.",
+ ]
+
+ self.assertListEqual(expected_output_string, output_strings)
+ self.assertListEqual(expected_output_string, output_strings_xla)
+
@slow
def test_beam_search_generate(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
@@ -640,9 +659,9 @@ def test_small_integration_test(self):
labels = tokenizer("Hi I am", return_tensors="tf").input_ids
loss = model(input_ids, labels=labels).loss
- mtf_score = -tf.math.reduce_sum(loss).numpy()
+ mtf_score = -tf.math.reduce_mean(loss).numpy()
- EXPECTED_SCORE = -19.0845
+ EXPECTED_SCORE = -4.771147
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
@slow
@@ -666,9 +685,9 @@ def test_small_v1_1_integration_test(self):
labels = tokenizer("Hi I am", return_tensors="tf").input_ids
loss = model(input_ids, labels=labels).loss
- mtf_score = -tf.math.reduce_sum(loss).numpy()
+ mtf_score = -tf.math.reduce_mean(loss).numpy()
- EXPECTED_SCORE = -59.0293
+ EXPECTED_SCORE = -14.757326
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
@slow
@@ -690,9 +709,9 @@ def test_small_byt5_integration_test(self):
labels = tokenizer("Hi I am", return_tensors="tf").input_ids
loss = model(input_ids, labels=labels).loss
- mtf_score = -tf.math.reduce_sum(loss).numpy()
+ mtf_score = -tf.math.reduce_mean(loss).numpy()
- EXPECTED_SCORE = -60.7397
+ EXPECTED_SCORE = -7.592465
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
@slow
@@ -700,19 +719,211 @@ def test_summarization(self):
model = self.model
tok = T5Tokenizer.from_pretrained("t5-base")
- FRANCE_ARTICLE = 'Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noqa
+ FRANCE_ARTICLE = ( # @noqa
+ "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
+ " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane."
+ ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."'
+ ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s'
+ " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French"
+ " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a"
+ " phone at the wreckage site. The two publications described the supposed video, but did not post it on"
+ " their websites. The publications said that they watched the video, which was found by a source close to"
+ " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported."
+ ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the'
+ " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the"
+ ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,'
+ " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said"
+ " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman"
+ " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the"
+ ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,'
+ ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be'
+ " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by"
+ " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so"
+ " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could"
+ ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin'
+ ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match'
+ ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something'
+ " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the"
+ ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline'
+ " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the"
+ " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the"
+ ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of'
+ ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school'
+ " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in"
+ " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent"
+ " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and"
+ " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%"
+ ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was'
+ " sharing the information and documents -- including training and medical records -- with public"
+ " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the"
+ " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the"
+ " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash"
+ " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late"
+ " Tuesday that no visible human remains were left at the site but recovery teams would keep searching."
+ " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all"
+ " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said."
+ " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew"
+ " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with"
+ " the flight school during his training were among several developments as investigators continued to"
+ " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa"
+ " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his"
+ ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in'
+ " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at"
+ " some point before his aviation career and underwent psychotherapy before he got his pilot's license."
+ " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the"
+ " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to"
+ " lose his pilot's license, a European government official briefed on the investigation told CNN on"
+ ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being'
+ " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that"
+ " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would"
+ " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had"
+ " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded"
+ " he had psychological issues, the European government official said. But no matter what details emerge"
+ " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact'
+ " that maybe they weren't going to keep doing their job and they're upset about that and so they're"
+ ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to'
+ " also take that rage and turn it outward on 149 other people who had nothing to do with the person's"
+ ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight'
+ " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura"
+ " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine"
+ " Amiel and Anna-Maja Rappard contributed to this report."
+ )
- SHORTER_ARTICLE = '(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
+ SHORTER_ARTICLE = (
+ "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
+ " and Faith Karimi contributed to this report."
+ )
- IRAN_ARTICLE = "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions."
+ IRAN_ARTICLE = (
+ "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran"
+ " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively"
+ " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger."
+ " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli"
+ " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a"
+ " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since"
+ " the announcement of the new framework will likely result in more heat than light. It will not be helped"
+ " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ."
+ " The most misleading assertion, despite universal rejection by experts, is that the negotiations'"
+ " objective at the outset was the total elimination of any nuclear program in Iran. That is the position"
+ " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it"
+ " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has"
+ " always been to structure an agreement or series of agreements so that Iran could not covertly develop a"
+ " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded"
+ " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by"
+ " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another"
+ " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite"
+ " sharp accusations by some in the United States and its allies, Iran denies having such a program, and"
+ " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's"
+ " continued cooperation with International Atomic Energy Agency inspections is further evidence on this"
+ " point, and we'll know even more about Iran's program in the coming months and years because of the deal."
+ " In fact, the inspections provisions that are part of this agreement are designed to protect against any"
+ " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that"
+ " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter"
+ " warning that a deal might be killed by Congress or a future president). This of course is not the case."
+ " The talks were between Iran and the five permanent members of the U.N. Security Council (United States,"
+ " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has"
+ " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement"
+ " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran"
+ " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement"
+ " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the"
+ " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased"
+ " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes"
+ " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear"
+ " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going"
+ " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such"
+ " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the"
+ ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not'
+ " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New"
+ " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement"
+ " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement"
+ " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove"
+ " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally"
+ " some insist that any agreement must address Iranian missile programs, human rights violations or support"
+ " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are"
+ " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in"
+ " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it"
+ " affects the security of our negotiating partners and allies, including Israel. Those judgments should be"
+ " fact-based, not based on questionable assertions or dubious assumptions."
+ )
- ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
+ ARTICLE_SUBWAY = (
+ "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
+ " year later, she got married again in Westchester County, but to a different man and without divorcing"
+ " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
+ ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
+ " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
+ ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
+ ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
+ " license application, according to court documents. Prosecutors said the marriages were part of an"
+ " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
+ " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
+ " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
+ " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
+ " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
+ " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
+ " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
+ " said the immigration scam involved some of her husbands, who filed for permanent residence status"
+ " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
+ " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
+ " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
+ ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
+ " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
+ " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
+ " up to four years in prison. Her next court appearance is scheduled for May 18."
+ )
expected_summaries = [
- 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says .',
- "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .",
- "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and implement a rigorous inspection regime .",
- 'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
+ 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
+ " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
+ " magazine says .",
+ "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
+ " preliminary examination into the situation in the occupied Palestinian territory . as members of the"
+ " court, Palestinians may be subject to counter-charges as well .",
+ "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:"
+ " the debate that has already begun since the announcement of the new framework will likely result in more"
+ " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and"
+ " implement a rigorous inspection regime .",
+ "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two"
+ ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10'
+ " times, with nine of her marriages occurring between 1999 and 2002 .",
]
task_specific_config = getattr(model.config, "task_specific_params", {})
@@ -787,7 +998,10 @@ def test_translation_en_to_fr(self):
translation_config = task_specific_config.get("translation_en_to_fr", {})
model.config.update(translation_config)
- en_text = ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of countless generations of stars: the oldest stars are seen as blue dots. '
+ en_text = (
+ ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of'
+ " countless generations of stars: the oldest stars are seen as blue dots. "
+ )
new_truncated_translation = (
"Cette section d'images provenant de l'enregistrement infrarouge effectuƩ par le tƩlescope Spitzer montre "
diff --git a/tests/models/t5/test_tokenization_t5.py b/tests/models/t5/test_tokenization_t5.py
index 1c0fde222cdb..28d85c77c97c 100644
--- a/tests/models/t5/test_tokenization_t5.py
+++ b/tests/models/t5/test_tokenization_t5.py
@@ -210,10 +210,9 @@ def test_max_length(self):
"Summary of the text.",
"Another summary.",
]
- with tokenizer.as_target_tokenizer():
- targets = tokenizer(
- tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
- )
+ targets = tokenizer(
+ text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
+ )
self.assertEqual(32, targets["input_ids"].shape[1])
def test_outputs_not_longer_than_maxlen(self):
@@ -235,12 +234,10 @@ def test_eos_in_input(self):
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1]
expected_tgt_tokens = [20698, 13, 8, 1499, 5, 1]
- batch = tokenizer(src_text)
- with tokenizer.as_target_tokenizer():
- targets = tokenizer(tgt_text)
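+ # with text_target, the tokenizer encodes the targets in the same call and returns them under "labels"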
+ batch = tokenizer(src_text, text_target=tgt_text)
self.assertEqual(expected_src_tokens, batch["input_ids"][0])
- self.assertEqual(expected_tgt_tokens, targets["input_ids"][0])
+ self.assertEqual(expected_tgt_tokens, batch["labels"][0])
def test_token_type_ids(self):
src_text_1 = ["A first paragraph for summarization."]
diff --git a/tests/models/tapas/test_modeling_tapas.py b/tests/models/tapas/test_modeling_tapas.py
index 31c9b38c8f86..b7b4af6e5a2a 100644
--- a/tests/models/tapas/test_modeling_tapas.py
+++ b/tests/models/tapas/test_modeling_tapas.py
@@ -32,7 +32,13 @@
is_torch_available,
)
from transformers.models.auto import get_values
-from transformers.testing_utils import require_scatter, require_torch, slow, torch_device
+from transformers.testing_utils import (
+ require_scatter,
+ require_tensorflow_probability,
+ require_torch,
+ slow,
+ torch_device,
+)
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -499,6 +505,10 @@ def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+ @require_tensorflow_probability
+ def test_pt_tf_model_equivalence(self):
+ super().test_pt_tf_model_equivalence()
+
def prepare_tapas_single_inputs_for_inference():
# Here we prepare a single table-question pair to test TAPAS inference on:
diff --git a/tests/models/tapas/test_modeling_tf_tapas.py b/tests/models/tapas/test_modeling_tf_tapas.py
index 147eb472702e..bf5e8be370c7 100644
--- a/tests/models/tapas/test_modeling_tf_tapas.py
+++ b/tests/models/tapas/test_modeling_tf_tapas.py
@@ -498,6 +498,10 @@ def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+ @unittest.skip(reason="The default test gets NaN losses with the test-generated inputs")
+ def test_dataset_conversion(self):
+ pass
+
def prepare_tapas_single_inputs_for_inference():
# Here we prepare a single table-question pair to test TAPAS inference on:
diff --git a/tests/models/tapas/test_tokenization_tapas.py b/tests/models/tapas/test_tokenization_tapas.py
index 002f8c7e7549..f712f324f954 100644
--- a/tests/models/tapas/test_tokenization_tapas.py
+++ b/tests/models/tapas/test_tokenization_tapas.py
@@ -36,6 +36,7 @@
is_pt_tf_cross_test,
require_pandas,
require_scatter,
+ require_tensorflow_probability,
require_tokenizers,
require_torch,
slow,
@@ -141,6 +142,10 @@ def get_input_output_texts(self, tokenizer):
output_text = "unwanted, running"
return input_text, output_text
+ @require_tensorflow_probability
+ def test_tf_encode_plus_sent_to_model(self):
+ super().test_tf_encode_plus_sent_to_model()
+
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
return
@@ -251,7 +256,7 @@ def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
vocab = {}
- for (i, token) in enumerate(vocab_tokens):
+ for i, token in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
diff --git a/tests/models/tapex/test_tokenization_tapex.py b/tests/models/tapex/test_tokenization_tapex.py
index c959b780215b..dec0f507ed3c 100644
--- a/tests/models/tapex/test_tokenization_tapex.py
+++ b/tests/models/tapex/test_tokenization_tapex.py
@@ -859,9 +859,8 @@ def test_tokenizer_as_target(self):
tokenizer = TapexTokenizer.from_pretrained("microsoft/tapex-base")
answer_text = "tapex is a good model!"
expected_src_tokens = [0, 90, 5776, 1178, 16, 10, 205, 1421, 328, 2]
- with tokenizer.as_target_tokenizer():
- answer_encoding = tokenizer(answer=answer_text)
- self.assertListEqual(answer_encoding.input_ids, expected_src_tokens)
+ answer_encoding = tokenizer(answer=answer_text)
+ self.assertListEqual(answer_encoding.input_ids, expected_src_tokens)
@slow
def test_tokenizer_lower_case(self):
@@ -870,23 +869,21 @@ def test_tokenizer_lower_case(self):
answer_text = "Beijing, London, Paris"
answer_text_lower = "beijing, london, paris"
- with cased_tokenizer.as_target_tokenizer():
- with uncased_tokenizer.as_target_tokenizer():
- self.assertNotEqual(
- cased_tokenizer(answer=answer_text).input_ids, uncased_tokenizer(answer=answer_text).input_ids
- )
- self.assertEqual(
- cased_tokenizer(answer=answer_text_lower).input_ids,
- uncased_tokenizer(answer=answer_text).input_ids,
- )
- # batched encoding assert
- self.assertNotEqual(
- cased_tokenizer(answer=[answer_text]).input_ids, uncased_tokenizer(answer=[answer_text]).input_ids
- )
- self.assertEqual(
- cased_tokenizer(answer=[answer_text_lower]).input_ids,
- uncased_tokenizer(answer=[answer_text]).input_ids,
- )
+ self.assertNotEqual(
+ cased_tokenizer(answer=answer_text).input_ids, uncased_tokenizer(answer=answer_text).input_ids
+ )
+ self.assertEqual(
+ cased_tokenizer(answer=answer_text_lower).input_ids,
+ uncased_tokenizer(answer=answer_text).input_ids,
+ )
+ # batched encoding assert
+ self.assertNotEqual(
+ cased_tokenizer(answer=[answer_text]).input_ids, uncased_tokenizer(answer=[answer_text]).input_ids
+ )
+ self.assertEqual(
+ cased_tokenizer(answer=[answer_text_lower]).input_ids,
+ uncased_tokenizer(answer=[answer_text]).input_ids,
+ )
# test input encoding lowercase
question = "Greece held its last Summer Olympics in 2004"
table_dict = {
diff --git a/tests/models/trajectory_transformer/__init__.py b/tests/models/trajectory_transformer/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/trajectory_transformer/test_modeling_trajectory_transformer.py b/tests/models/trajectory_transformer/test_modeling_trajectory_transformer.py
new file mode 100644
index 000000000000..7cf5c741a1f6
--- /dev/null
+++ b/tests/models/trajectory_transformer/test_modeling_trajectory_transformer.py
@@ -0,0 +1,275 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch TrajectoryTransformer model. """
+
+
+import inspect
+import unittest
+
+import numpy as np
+
+from transformers import TrajectoryTransformerConfig, is_torch_available
+from transformers.testing_utils import require_torch, slow, torch_device
+
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, random_attention_mask
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import TrajectoryTransformerModel
+ from transformers.models.trajectory_transformer.modeling_trajectory_transformer import (
+ TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ )
+
+
+class TrajectoryTransformerModelTester:
+ def __init__(self, parent, batch_size=13, n_embd=128, action_dim=6, observation_dim=17, is_training=True):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.n_embd = n_embd
+ self.action_dim = action_dim
+ self.observation_dim = observation_dim
+ self.is_training = is_training
+ self.seq_length = self.action_dim + self.observation_dim + 1
+
+ def prepare_config_and_inputs(self):
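+ # trajectories and targets are random permutations of range(seq_length), used as discretized token ids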
+ trajectories = torch.LongTensor([np.random.permutation(self.seq_length) for _ in range(self.batch_size)]).to(
+ torch_device
+ )
+ attention_mask = random_attention_mask((self.batch_size, self.seq_length)).to(torch_device)
+ targets = torch.LongTensor([np.random.permutation(self.seq_length) for _ in range(self.batch_size)]).to(
+ torch_device
+ )
+
+ config = self.get_config()
+ return config, trajectories, attention_mask, targets
+
+ def get_config(self):
+ return TrajectoryTransformerConfig(
+ batch_size=self.batch_size,
+ n_embd=self.n_embd,
+ action_dim=self.action_dim,
+ observation_dim=self.observation_dim,
+ )
+
+ def create_and_check_model(self, config, input_dict):
+ model = TrajectoryTransformerModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(trajectories=input_dict["trajectories"], attention_mask=input_dict["attention_mask"])
+ result = model(
+ trajectories=input_dict["trajectories"],
+ output_hidden_states=True,
+ output_attentions=True,
+ use_cache=True,
+ return_dict=True,
+ )
+
+ self.parent.assertEqual(result.hidden_states[-1].shape, (self.batch_size, self.seq_length, self.n_embd))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (config, trajectories, attention_mask, targets) = config_and_inputs
+ inputs_dict = {"trajectories": trajectories, "attention_mask": attention_mask, "targets": targets}
+ return config, inputs_dict
+
+
+@require_torch
+class TrajectoryTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+
+ all_model_classes = (TrajectoryTransformerModel,) if is_torch_available() else ()
+
+ # Ignore a failing test from GenerationTesterMixin, as the model does not use input_ids
+ test_generate_without_input_ids = False
+
+ # Ignore failing tests from ModelTesterMixin, as the model does not implement these features
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ test_attention_outputs = False
+ test_hidden_states_output = False
+ test_inputs_embeds = False
+ test_model_common_attributes = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = TrajectoryTransformerModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=TrajectoryTransformerConfig, n_embd=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_conditional_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["trajectories"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ # Input is 'trajectories' not 'input_ids'
+ def test_model_main_input_name(self):
+ model_signature = inspect.signature(getattr(TrajectoryTransformerModel, "forward"))
+ # The main input is the name of the argument after `self`
+ observed_main_input_name = list(model_signature.parameters.keys())[1]
+ self.assertEqual(TrajectoryTransformerModel.main_input_name, observed_main_input_name)
+
+ def test_retain_grad_hidden_states_attentions(self):
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = self.has_attentions
+
+ model = TrajectoryTransformerModel(config)
+ model.to(torch_device)
+
+ outputs = model(
+ trajectories=input_dict["trajectories"],
+ attention_mask=input_dict["attention_mask"],
+ targets=input_dict["targets"],
+ output_hidden_states=True,
+ output_attentions=True,
+ use_cache=True,
+ return_dict=True,
+ )
+
+ output = outputs[0]
+ hidden_states = outputs.hidden_states[0]
+ hidden_states.retain_grad()
+
+ if self.has_attentions:
+ attentions = outputs.attentions[0]
+ attentions.retain_grad()
+
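+ # backpropagate from a single scalar of the output so gradients reach the retained hidden states and attentions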
+ output.flatten()[0].backward(retain_graph=True)
+
+ self.assertIsNotNone(hidden_states.grad)
+
+ if self.has_attentions:
+ self.assertIsNotNone(attentions.grad)
+
+ def test_training(self):
+ if not self.model_tester.is_training:
+ return
+
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ model = TrajectoryTransformerModel(config)
+ model.to(torch_device)
+ model.train()
+ loss = model(
+ trajectories=input_dict["trajectories"],
+ attention_mask=input_dict["attention_mask"],
+ targets=input_dict["targets"],
+ output_hidden_states=True,
+ output_attentions=True,
+ use_cache=True,
+ return_dict=True,
+ ).loss
+ loss.backward()
+
+ def test_training_gradient_checkpointing(self):
+ if not self.model_tester.is_training:
+ return
+
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ model = TrajectoryTransformerModel(config)
+ model.gradient_checkpointing_enable()
+ model.to(torch_device)
+ model.train()
+ loss = model(
+ trajectories=input_dict["trajectories"],
+ attention_mask=input_dict["attention_mask"],
+ targets=input_dict["targets"],
+ output_hidden_states=True,
+ output_attentions=True,
+ use_cache=False,
+ return_dict=True,
+ ).loss
+ loss.backward()
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = TrajectoryTransformerModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+@require_torch
+class TrajectoryTransformerModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_prediction(self):
+ batch_size = 1
+
+ config = TrajectoryTransformerConfig.from_pretrained("CarlCochet/trajectory-transformer-halfcheetah-medium-v2")
+ model = TrajectoryTransformerModel.from_pretrained(
+ "CarlCochet/trajectory-transformer-halfcheetah-medium-v2", config=config
+ )
+ model.to(torch_device)
+ model.eval()
+
+ seq_length = model.config.action_dim + model.config.observation_dim + 1
+
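+ # a single hand-picked trajectory of seq_length (= action_dim + observation_dim + 1) discretized token ids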
+ trajectories = torch.LongTensor(
+ [[3, 19, 20, 22, 9, 7, 23, 10, 18, 14, 13, 4, 17, 11, 5, 6, 15, 21, 2, 8, 1, 0, 12, 16]]
+ ).to(torch_device)
+ outputs = model(
+ trajectories=trajectories,
+ output_hidden_states=True,
+ output_attentions=True,
+ use_cache=True,
+ return_dict=True,
+ )
+
+ output = outputs.logits
+
+ expected_shape = torch.Size((batch_size, seq_length, model.config.vocab_size + 1))
+ expected_slice = torch.tensor(
+ [[[-0.7193, -0.2532, -0.0898], [1.9429, 2.0434, 2.3975], [-3.3651, -2.8744, -2.4532]]]
+ ).to(torch_device)
+ output_slice = output[:, :3, :3]
+
+ self.assertEqual(output.shape, expected_shape)
+ self.assertTrue(torch.allclose(output_slice, expected_slice, atol=1e-4))
diff --git a/tests/models/transfo_xl/test_modeling_tf_transfo_xl.py b/tests/models/transfo_xl/test_modeling_tf_transfo_xl.py
index 129b2ac4cf73..84e25d8716f5 100644
--- a/tests/models/transfo_xl/test_modeling_tf_transfo_xl.py
+++ b/tests/models/transfo_xl/test_modeling_tf_transfo_xl.py
@@ -216,6 +216,10 @@ def test_model_from_pretrained(self):
model = TFTransfoXLModel.from_pretrained(model_name)
self.assertIsNotNone(model)
+ @unittest.skip(reason="This model doesn't play well with fit() due to not returning a single loss.")
+ def test_dataset_conversion(self):
+ pass
+
@require_tf
class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
diff --git a/tests/models/trocr/test_modeling_trocr.py b/tests/models/trocr/test_modeling_trocr.py
index 6d8ff0aa606e..0c5e6f7ae8f9 100644
--- a/tests/models/trocr/test_modeling_trocr.py
+++ b/tests/models/trocr/test_modeling_trocr.py
@@ -161,6 +161,7 @@ def prepare_config_and_inputs_for_common(self):
class TrOCRStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (TrOCRDecoder, TrOCRForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (TrOCRForCausalLM,) if is_torch_available() else ()
+ fx_compatible = True
test_pruning = False
def setUp(self):
diff --git a/tests/models/van/test_modeling_van.py b/tests/models/van/test_modeling_van.py
index 3e5b7fb1dfc7..6b6a672b9b4f 100644
--- a/tests/models/van/test_modeling_van.py
+++ b/tests/models/van/test_modeling_van.py
@@ -144,6 +144,10 @@ def test_config(self):
def create_and_test_config_common_properties(self):
return
+ @unittest.skip(reason="Van does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
@unittest.skip(reason="Van does not use inputs_embeds")
def test_inputs_embeds(self):
pass
diff --git a/tests/models/videomae/__init__.py b/tests/models/videomae/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/videomae/test_feature_extraction_videomae.py b/tests/models/videomae/test_feature_extraction_videomae.py
new file mode 100644
index 000000000000..cfe00f51e5e5
--- /dev/null
+++ b/tests/models/videomae/test_feature_extraction_videomae.py
@@ -0,0 +1,202 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_video_inputs
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import VideoMAEFeatureExtractor
+
+
+class VideoMAEFeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ num_frames=10,
+ image_size=18,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=18,
+ do_normalize=True,
+ image_mean=[0.5, 0.5, 0.5],
+ image_std=[0.5, 0.5, 0.5],
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.num_frames = num_frames
+ self.image_size = image_size
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_normalize": self.do_normalize,
+ "do_resize": self.do_resize,
+ "size": self.size,
+ }
+
+
+@require_torch
+@require_vision
+class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = VideoMAEFeatureExtractor if is_vision_available() else None
+
+ def setUp(self):
+ self.feature_extract_tester = VideoMAEFeatureExtractionTester(self)
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "image_std"))
+ self.assertTrue(hasattr(feature_extractor, "do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "size"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL videos
+ video_inputs = prepare_video_inputs(self.feature_extract_tester, equal_resolution=False)
+ for video in video_inputs:
+ self.assertIsInstance(video, list)
+ self.assertIsInstance(video[0], Image.Image)
+
+ # Test not batched input
+ encoded_videos = feature_extractor(video_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_videos.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_frames,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ # Test batched
+ encoded_videos = feature_extractor(video_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_videos.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_frames,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ def test_call_numpy(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random numpy tensors
+ video_inputs = prepare_video_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
+ for video in video_inputs:
+ self.assertIsInstance(video, list)
+ self.assertIsInstance(video[0], np.ndarray)
+
+ # Test not batched input
+ encoded_videos = feature_extractor(video_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_videos.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_frames,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ # Test batched
+ encoded_videos = feature_extractor(video_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_videos.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_frames,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ def test_call_pytorch(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PyTorch tensors
+ video_inputs = prepare_video_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+ for video in video_inputs:
+ self.assertIsInstance(video, list)
+ self.assertIsInstance(video[0], torch.Tensor)
+
+ # Test not batched input
+ encoded_videos = feature_extractor(video_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_videos.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_frames,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ # Test batched
+ encoded_videos = feature_extractor(video_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_videos.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_frames,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
diff --git a/tests/models/videomae/test_modeling_videomae.py b/tests/models/videomae/test_modeling_videomae.py
new file mode 100644
index 000000000000..adce62021c9d
--- /dev/null
+++ b/tests/models/videomae/test_modeling_videomae.py
@@ -0,0 +1,421 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch VideoMAE model. """
+
+
+import copy
+import inspect
+import unittest
+
+import numpy as np
+
+from huggingface_hub import hf_hub_download
+from transformers import VideoMAEConfig
+from transformers.models.auto import get_values
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.utils import cached_property, is_torch_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import (
+ MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
+ VideoMAEForPreTraining,
+ VideoMAEForVideoClassification,
+ VideoMAEModel,
+ )
+ from transformers.models.videomae.modeling_videomae import VIDEOMAE_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+if is_vision_available():
+ from transformers import VideoMAEFeatureExtractor
+
+
+class VideoMAEModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=10,
+ num_channels=3,
+ patch_size=2,
+ tubelet_size=2,
+ num_frames=2,
+ is_training=True,
+ use_labels=True,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ type_sequence_label_size=10,
+ initializer_range=0.02,
+ mask_ratio=0.9,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.tubelet_size = tubelet_size
+ self.num_frames = num_frames
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.mask_ratio = mask_ratio
+ self.scope = scope
+
+ # in VideoMAE, the number of tokens equals num_frames/tubelet_size * num_patches per frame
+ self.num_patches_per_frame = (image_size // patch_size) ** 2
+ self.seq_length = (num_frames // tubelet_size) * self.num_patches_per_frame
+
+ # use this variable to define bool_masked_pos
+ self.num_masks = int(mask_ratio * self.seq_length)
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor(
+ [self.batch_size, self.num_frames, self.num_channels, self.image_size, self.image_size]
+ )
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return VideoMAEConfig(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ num_frames=self.num_frames,
+ tubelet_size=self.tubelet_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ is_decoder=False,
+ initializer_range=self.initializer_range,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = VideoMAEModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ def create_and_check_for_pretraining(self, config, pixel_values, labels):
+ model = VideoMAEForPreTraining(config)
+ model.to(torch_device)
+ model.eval()
+ # important: each video needs to have the same number of masked patches
+ # hence we define a single mask, which we then repeat for each example in the batch
+ mask = torch.ones((self.num_masks,))
+ mask = torch.cat([mask, torch.zeros(self.seq_length - mask.size(0))])
+ bool_masked_pos = mask.expand(self.batch_size, -1).bool()
+
+ result = model(pixel_values, bool_masked_pos)
+ # model only returns predictions for masked patches
+ num_masked_patches = mask.sum().item()
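+ # each masked patch is reconstructed as raw pixel values: num_channels (3) * tubelet_size * patch_size**2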
+ decoder_num_labels = 3 * self.tubelet_size * self.patch_size**2
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, num_masked_patches, decoder_num_labels))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class VideoMAEModelTest(ModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as VideoMAE does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (
+ (VideoMAEModel, VideoMAEForPreTraining, VideoMAEForVideoClassification) if is_torch_available() else ()
+ )
+
+ test_pruning = False
+ test_torchscript = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = VideoMAEModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=VideoMAEConfig, has_text_modality=False, hidden_size=37)
+
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ inputs_dict = copy.deepcopy(inputs_dict)
+
+ if model_class == VideoMAEForPreTraining:
+ # important: each video needs to have the same number of masked patches
+ # hence we define a single mask, which we then repeat for each example in the batch
+ mask = torch.ones((self.model_tester.num_masks,))
+ mask = torch.cat([mask, torch.zeros(self.model_tester.seq_length - mask.size(0))])
+ bool_masked_pos = mask.expand(self.model_tester.batch_size, -1).bool()
+ inputs_dict["bool_masked_pos"] = bool_masked_pos.to(torch_device)
+
+ if return_labels:
+ if model_class in [
+ *get_values(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING),
+ ]:
+ inputs_dict["labels"] = torch.zeros(
+ self.model_tester.batch_size, dtype=torch.long, device=torch_device
+ )
+
+ return inputs_dict
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(reason="VideoMAE does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ def test_model_common_attributes(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_pretraining(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in VIDEOMAE_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = VideoMAEModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+ def test_attention_outputs(self):
+ if not self.has_attentions:
+ pass
+
+ else:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ for model_class in self.all_model_classes:
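+ # for VideoMAEForPreTraining only the visible (unmasked) patches are fed to the encoder, so attention maps use a shorter sequence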
+ num_visible_patches = self.model_tester.seq_length - self.model_tester.num_masks
+ seq_len = (
+ num_visible_patches if model_class == VideoMAEForPreTraining else self.model_tester.seq_length
+ )
+
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, seq_len, seq_len],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ self.assertEqual(out_len + 1, len(outputs))
+
+ self_attentions = outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, seq_len, seq_len],
+ )
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.hidden_states
+ expected_num_layers = self.model_tester.num_hidden_layers + 1
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ num_visible_patches = self.model_tester.seq_length - self.model_tester.num_masks
+ seq_length = num_visible_patches if model_class == VideoMAEForPreTraining else self.model_tester.seq_length
+
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [seq_length, self.model_tester.hidden_size],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+
+# We will verify our results on a video of eating spaghetti
+# Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227]
+def prepare_video():
+ file = hf_hub_download(repo_id="datasets/hf-internal-testing/spaghetti-video", filename="eating_spaghetti.npy")
+ video = np.load(file)
+ return list(video)
+
+
+@require_torch
+@require_vision
+class VideoMAEModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ # logits were tested with a different mean and std, so we use the same here
+ return (
+ VideoMAEFeatureExtractor(image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5])
+ if is_vision_available()
+ else None
+ )
+
+ @slow
+ def test_inference_for_video_classification(self):
+ model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics").to(
+ torch_device
+ )
+
+ feature_extractor = self.default_feature_extractor
+ video = prepare_video()
+ inputs = feature_extractor(video, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ expected_shape = torch.Size((1, 400))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor([0.3669, -0.0688, -0.2421]).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
+
+ @slow
+ def test_inference_for_pretraining(self):
+ model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base-short").to(torch_device)
+
+ feature_extractor = self.default_feature_extractor
+ video = prepare_video()
+ inputs = feature_extractor(video, return_tensors="pt").to(torch_device)
+
+ # add boolean mask, indicating which patches to mask
+ local_path = hf_hub_download(repo_id="hf-internal-testing/bool-masked-pos", filename="bool_masked_pos.pt")
+ inputs["bool_masked_pos"] = torch.load(local_path)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ expected_shape = torch.Size([1, 1408, 1536])
+ expected_slice = torch.tensor(
+ [[0.7994, 0.9612, 0.8508], [0.7401, 0.8958, 0.8302], [0.5862, 0.7468, 0.7325]], device=torch_device
+ )
+ self.assertEqual(outputs.logits.shape, expected_shape)
+ self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice, atol=1e-4))
+
+ # verify the loss (`config.norm_pix_loss` = `True`)
+ expected_loss = torch.tensor([0.5142], device=torch_device)
+ self.assertTrue(torch.allclose(outputs.loss, expected_loss, atol=1e-4))
+
+ # verify the loss (`config.norm_pix_loss` = `False`)
+ model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base-short", norm_pix_loss=False).to(
+ torch_device
+ )
+
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ expected_loss = torch.tensor([0.6469], device=torch_device)
+ self.assertTrue(torch.allclose(outputs.loss, expected_loss, atol=1e-4))
diff --git a/tests/models/vilt/test_modeling_vilt.py b/tests/models/vilt/test_modeling_vilt.py
index 0c6783c439a3..82aa0767470e 100644
--- a/tests/models/vilt/test_modeling_vilt.py
+++ b/tests/models/vilt/test_modeling_vilt.py
@@ -37,6 +37,7 @@
ViltForImagesAndTextClassification,
ViltForMaskedLM,
ViltForQuestionAnswering,
+ ViltForTokenClassification,
ViltModel,
)
from transformers.models.vilt.modeling_vilt import VILT_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -173,6 +174,23 @@ def create_and_check_model(
result.last_hidden_state.shape, (self.batch_size, self.expected_seq_len, self.hidden_size)
)
+ def create_and_check_for_token_classification(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ pixel_values,
+ token_labels,
+ ):
+ model = ViltForTokenClassification(config=config)
+ model.to(torch_device)
+ model.eval()
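+ # exercise the forward pass with and without the optional attention_mask and token_type_ids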
+ result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, pixel_values=pixel_values)
+ result = model(input_ids, token_type_ids=token_type_ids, pixel_values=pixel_values)
+ result = model(input_ids, pixel_values=pixel_values)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
+
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -204,6 +222,7 @@ class ViltModelTest(ModelTesterMixin, unittest.TestCase):
ViltForQuestionAnswering,
ViltForImageAndTextRetrieval,
ViltForMaskedLM,
+ ViltForTokenClassification,
)
if is_torch_available()
else ()
@@ -216,15 +235,12 @@ class ViltModelTest(ModelTesterMixin, unittest.TestCase):
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
- # if model_class.__name__ == "ViltForNaturalLanguageVisualReasonining":
- # inputs_dict["pixel_values"] = floats_tensor([self.model_tester.batch_size, self.model_tester.num_images, self.model_tester.num_channels, self.model_tester.image_size, self.model_tester.image_size])
-
if return_labels:
if model_class.__name__ == "ViltForQuestionAnswering":
inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, self.model_tester.num_labels, device=torch_device
)
- elif model_class.__name__ == "ViltForMaskedLM":
+ elif model_class.__name__ in ["ViltForMaskedLM", "ViltForTokenClassification"]:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)
@@ -246,6 +262,10 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
+ def test_for_token_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
+
def test_training(self):
if not self.model_tester.is_training:
return
@@ -503,6 +523,10 @@ def setUp(self):
def test_model(self):
pass
+ @unittest.skip("We only test the model that takes in multiple images")
+ def test_for_token_classification(self):
+ pass
+
# We will verify our results on an image of cute cats
def prepare_img():
@@ -589,7 +613,10 @@ def test_inference_natural_language_visual_reasoning(self):
image1 = Image.open(dataset[0]["file"]).convert("RGB")
image2 = Image.open(dataset[1]["file"]).convert("RGB")
- text = "The left image contains twice the number of dogs as the right image, and at least two dogs in total are standing."
+ text = (
+ "The left image contains twice the number of dogs as the right image, and at least two dogs in total are"
+ " standing."
+ )
encoding_1 = processor(image1, text, return_tensors="pt")
encoding_2 = processor(image2, text, return_tensors="pt")
diff --git a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py
index 9edbd3f802fb..97ac81390530 100644
--- a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py
+++ b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py
@@ -31,6 +31,7 @@
slow,
torch_device,
)
+from transformers.utils.generic import ModelOutput
from ...test_modeling_tf_common import floats_tensor, ids_tensor
from ..gpt2.test_modeling_tf_gpt2 import TFGPT2ModelTester
@@ -314,31 +315,145 @@ def check_encoder_decoder_model_generate(self, pixel_values, config, decoder_con
tuple(generated_output.shape.as_list()), (pixel_values.shape[0],) + (decoder_config.max_length,)
)
- def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict):
+ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
+ """Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.
+
+ Args:
+ model_class: The class of the model that is currently testing. For example, `TFBertModel`,
+ TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Mainly used for providing more informative
+ error messages.
+ name (`str`): The name of the output. For example, `output.hidden_states`, `output.attentions`, etc.
+ attributes (`Tuple[str]`): The names of the output's element if the output is a tuple/list with each element
+ being a named field in the output.
+ """
+
+ self.assertEqual(type(name), str)
+ if attributes is not None:
+ self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
+
+ # Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
+ if isinstance(tf_outputs, ModelOutput):
+ self.assertTrue(
+ isinstance(pt_outputs, ModelOutput),
+ f"{name}: `pt_outputs` should an instance of `ModelOutput` when `tf_outputs` is",
+ )
- pt_model.to(torch_device)
- pt_model.eval()
+ tf_keys = [k for k, v in tf_outputs.items() if v is not None]
+ pt_keys = [k for k, v in pt_outputs.items() if v is not None]
+
+ self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")
+
+ # convert to the case of `tuple`
+ # appending each key to the current (string) `name`
+ attributes = tuple([f"{name}.{k}" for k in tf_keys])
+ self.check_pt_tf_outputs(
+ tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
+ )
+
+ # Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
+ elif type(tf_outputs) in [tuple, list]:
+ self.assertEqual(type(tf_outputs), type(pt_outputs), f"{name}: Output types differ between TF and PyTorch")
+ self.assertEqual(len(tf_outputs), len(pt_outputs), f"{name}: Output lengths differ between TF and PyTorch")
+
+ if attributes is not None:
+ # case 1: each output has an assigned name (e.g. a tuple form of a `ModelOutput`)
+ self.assertEqual(
+ len(attributes),
+ len(tf_outputs),
+ f"{name}: The tuple `names` should have the same length as `tf_outputs`",
+ )
+ else:
+ # case 2: each output has no assigned name (e.g. hidden states of each layer) -> append an index to `name`
+ attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
+
+ for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
+ self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)
+
+ elif isinstance(tf_outputs, tf.Tensor):
+ self.assertTrue(
+ isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should be a tensor when `tf_outputs` is"
+ )
+
+ tf_outputs = tf_outputs.numpy()
+ pt_outputs = pt_outputs.detach().to("cpu").numpy()
+
+ self.assertEqual(
+ tf_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between TF and PyTorch"
+ )
+
+ # deal with NumPy's scalars to make replacing nan values by 0 work.
+ if np.isscalar(tf_outputs):
+ tf_outputs = np.array([tf_outputs])
+ pt_outputs = np.array([pt_outputs])
+
+ tf_nans = np.isnan(tf_outputs)
+ pt_nans = np.isnan(pt_outputs)
+
+ pt_outputs[tf_nans] = 0
+ tf_outputs[tf_nans] = 0
+ pt_outputs[pt_nans] = 0
+ tf_outputs[pt_nans] = 0
+
+ max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
+ self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
+ else:
+ raise ValueError(
+ "`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got"
+ f" {type(tf_outputs)} instead."
+ )
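# Illustrative sketch (not part of this diff) of the tensor-leaf comparison performed above: NaNs that
# occur in either output are zeroed in both arrays before the maximum absolute difference is compared
# to the tolerance. The values below are made up.
import numpy as np

tf_out = np.array([0.1, np.nan, 0.3])
pt_out = np.array([0.1, np.nan, 0.300004])

tf_nans, pt_nans = np.isnan(tf_out), np.isnan(pt_out)
for arr in (tf_out, pt_out):
    arr[tf_nans] = 0
    arr[pt_nans] = 0

max_diff = np.amax(np.abs(tf_out - pt_out))
assert max_diff <= 1e-5  # the default `tol` of `check_pt_tf_outputs`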
+
+ def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
+
+ pt_inputs_dict = {}
+ for name, key in tf_inputs_dict.items():
+ if type(key) == bool:
+ pt_inputs_dict[name] = key
+ elif name == "input_values":
+ pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
+ elif name == "pixel_values":
+ pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
+ elif name == "input_features":
+ pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
+ # other general float inputs
+ elif tf_inputs_dict[name].dtype.is_floating:
+ pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
+ else:
+ pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
+
+ return pt_inputs_dict
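# Illustrative usage sketch (toy inputs, not from the PR) of the conversion above: floating TF inputs
# become `torch.float32`, everything else becomes `torch.long`, mirroring what the PyTorch model expects.
import tensorflow as tf
import torch

tf_inputs = {
    "pixel_values": tf.random.uniform((1, 3, 30, 30)),              # float -> torch.float32
    "decoder_input_ids": tf.constant([[0, 5, 7]], dtype=tf.int32),  # integer -> torch.long
}
pt_inputs = {
    k: torch.from_numpy(v.numpy()).to(torch.float32 if v.dtype.is_floating else torch.long)
    for k, v in tf_inputs.items()
}
assert pt_inputs["pixel_values"].dtype == torch.float32
assert pt_inputs["decoder_input_ids"].dtype == torch.long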
- # prepare inputs
- tf_inputs = inputs_dict
- pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
- if "labels" in pt_inputs:
- pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
+ def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
+
+ pt_inputs_dict = self.prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
# send pytorch inputs to the correct device
- pt_inputs = {k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()}
+ pt_inputs_dict = {
+ k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
+ }
+
+ # send pytorch model to the correct device
+ pt_model.to(torch_device)
+
+ # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
+ pt_model.eval()
with torch.no_grad():
- pt_outputs = pt_model(**pt_inputs).to_tuple()
+ pt_outputs = pt_model(**pt_inputs_dict)
+ tf_outputs = tf_model(tf_inputs_dict)
+
+ # TF models' returned loss is usually a tensor rather than a scalar.
+ # (see `hf_compute_loss`: it uses `tf.keras.losses.Reduction.NONE`)
+ # Change it to a scalar here to match the PyTorch models' loss
+ tf_loss = getattr(tf_outputs, "loss", None)
+ if tf_loss is not None:
+ tf_outputs.loss = tf.math.reduce_mean(tf_loss)
+
+ self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model))
- tf_outputs = tf_model(**inputs_dict)
- if "loss" in tf_outputs:
- tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss)
- tf_outputs = tf_outputs.to_tuple()
- self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
+ def check_pt_tf_equivalence(self, tf_model, pt_model, tf_inputs_dict):
+ """Wrap `check_pt_tf_models` to further check PT -> TF again"""
- for tf_output, pt_output in zip(tf_outputs, pt_outputs):
- self.assert_almost_equals(tf_output.numpy(), pt_output.detach().to("cpu").numpy(), 1e-3)
+ self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
# PT -> TF
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
@@ -351,18 +466,16 @@ def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict):
# This is only for copying some specific attributes of this particular model.
tf_model_loaded.config = pt_model.config
- tf_outputs_loaded = tf_model_loaded(**inputs_dict)
- if "loss" in tf_outputs_loaded:
- tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss)
- tf_outputs_loaded = tf_outputs_loaded.to_tuple()
- self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
+ self.check_pt_tf_models(tf_model_loaded, pt_model, tf_inputs_dict)
- for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
- self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.detach().to("cpu").numpy(), 1e-3)
-
- def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):
+ def check_pt_to_tf_equivalence(self, config, decoder_config, tf_inputs_dict):
+ """EncoderDecoderModel requires special way to cross load (PT -> TF)"""
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
+ # Output all for aggressive testing
+ encoder_decoder_config.output_hidden_states = True
+ # All models tested in this file have attentions
+ encoder_decoder_config.output_attentions = True
pt_model = VisionEncoderDecoderModel(encoder_decoder_config)
@@ -376,11 +489,16 @@ def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):
# This is only for copying some specific attributes of this particular model.
tf_model.config = pt_model.config
- self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
+ self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
- def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
+ def check_tf_to_pt_equivalence(self, config, decoder_config, tf_inputs_dict):
+ """EncoderDecoderModel requires special way to cross load (TF -> PT)"""
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
+ # Output all for aggressive testing
+ encoder_decoder_config.output_hidden_states = True
+ # TODO: A generalizable way to determine this attribute
+ encoder_decoder_config.output_attentions = True
# Using `_tf_model`, the test will fail, because the weights of `_tf_model` get extended before saving
# the encoder/decoder models.
@@ -389,7 +507,7 @@ def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
# (the change in `src/transformers/modeling_tf_utils.py`)
_tf_model = TFVisionEncoderDecoderModel(encoder_decoder_config)
# Make sure model is built
- _tf_model(**inputs_dict)
+ _tf_model(**tf_inputs_dict)
# Using `tf_model` to pass the test.
encoder = _tf_model.encoder.__class__(encoder_decoder_config.encoder)
@@ -398,6 +516,7 @@ def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
encoder(encoder.dummy_inputs)
decoder(decoder.dummy_inputs)
tf_model = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
+ tf_model.config = encoder_decoder_config
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
@@ -409,7 +528,7 @@ def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
# This is only for copying some specific attributes of this particular model.
pt_model.config = tf_model.config
- self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
+ self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
def test_encoder_decoder_model(self):
config_inputs_dict = self.prepare_config_and_inputs()
@@ -448,7 +567,7 @@ def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
self.assertLessEqual(diff, tol, f"Difference between torch and tf is {diff} (>= {tol}).")
@is_pt_tf_cross_test
- def test_pt_tf_equivalence(self):
+ def test_pt_tf_model_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
labels = config_inputs_dict.pop("decoder_token_labels")
@@ -467,48 +586,58 @@ def test_pt_tf_equivalence(self):
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")
- inputs_dict = config_inputs_dict
- # `encoder_hidden_states` is not used in model call/forward
- del inputs_dict["encoder_hidden_states"]
-
- inputs_dict_with_labels = copy.copy(inputs_dict)
- inputs_dict_with_labels["labels"] = labels
+ # Output all for aggressive testing
+ config.output_hidden_states = True
+ decoder_config.output_hidden_states = True
+ # All models tested in this file have attentions
+ config.output_attentions = True
+ decoder_config.output_attentions = True
- # Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
- batch_size = inputs_dict["decoder_attention_mask"].shape[0]
- inputs_dict["decoder_attention_mask"] = tf.constant(
- np.concatenate([np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1)
- )
+ tf_inputs_dict = config_inputs_dict
+ # `encoder_hidden_states` is not used in model call/forward
+ del tf_inputs_dict["encoder_hidden_states"]
+
+ # Make sure no sequence has an all-zero attention mask, otherwise some tests fail due to the
+ # inconsistent use of large negative fill values (`1e-4`, `1e-9`, `1e-30`, `-inf`) across implementations.
+ for k in ["decoder_attention_mask"]:
+ attention_mask = tf_inputs_dict[k]
+
+ # Make sure there are no all-zero attention masks, to avoid failures for now.
+ # Put `1` at the beginning of sequences to make it still work when combining causal attention masks.
+ # TODO: remove this line once a fix regarding large negative values for attention mask is done.
+ attention_mask = tf.concat(
+ [tf.ones_like(attention_mask[:, :1], dtype=attention_mask.dtype), attention_mask[:, 1:]], axis=-1
+ )
+ tf_inputs_dict[k] = attention_mask
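# Why an all-zero attention mask is a problem (standalone sketch, not part of the PR): with an additive
# `-inf` mask a fully masked row turns the softmax into 0/0 = NaN, while finite fill values such as
# -1e9 keep it finite, so the two frameworks can disagree on those rows.
import numpy as np

def masked_softmax(scores, additive_mask):
    x = scores + additive_mask
    z = np.exp(x - np.max(x))
    return z / z.sum()

scores = np.zeros(4)
print(masked_softmax(scores, np.full(4, -np.inf)))  # [nan nan nan nan]
print(masked_softmax(scores, np.full(4, -1e9)))     # [0.25 0.25 0.25 0.25]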
- # TF models don't use the `use_cache` option and cache is not returned as a default.
- # So we disable `use_cache` here for PyTorch model.
- decoder_config.use_cache = False
+ tf_inputs_dict_with_labels = copy.copy(tf_inputs_dict)
+ tf_inputs_dict_with_labels["labels"] = labels
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
- # check without `enc_to_dec_proj` projection
+ # Original test: check without `labels` and without `enc_to_dec_proj` projection
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
- self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
- self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
+ self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict)
+ self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict)
- # check equivalence with labels
- self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels)
- self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels)
+ # check with `labels`
+ self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict_with_labels)
+ self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict_with_labels)
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
# which randomly initialize `enc_to_dec_proj`.
- # # check `enc_to_dec_proj` work as expected
+ # check `enc_to_dec_proj` works as expected
# decoder_config.hidden_size = decoder_config.hidden_size * 2
# self.assertTrue(config.hidden_size != decoder_config.hidden_size)
- # self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
- # self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
+ # self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict)
+ # self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict)
# Let's just check `enc_to_dec_proj` can run for now
decoder_config.hidden_size = decoder_config.hidden_size * 2
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
model = TFVisionEncoderDecoderModel(encoder_decoder_config)
- model(**inputs_dict)
+ model(tf_inputs_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
index f8ac8f1cdf1c..320cdd633062 100644
--- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
+++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
@@ -20,7 +20,7 @@
from datasets import load_dataset
from packaging import version
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import require_torch, require_vision, slow, to_2tuple, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
@@ -48,7 +48,6 @@
ViTModel,
)
from transformers.modeling_outputs import BaseModelOutput
- from transformers.models.vit.modeling_vit import to_2tuple
if is_vision_available():
diff --git a/tests/models/vit/test_modeling_flax_vit.py b/tests/models/vit/test_modeling_flax_vit.py
index 56fe28d41baf..611f93648854 100644
--- a/tests/models/vit/test_modeling_flax_vit.py
+++ b/tests/models/vit/test_modeling_flax_vit.py
@@ -91,8 +91,7 @@ def prepare_config_and_inputs(self):
return config, pixel_values
- def create_and_check_model(self, config, pixel_values, labels):
-
+ def create_and_check_model(self, config, pixel_values):
model = FlaxViTModel(config=config)
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
@@ -101,6 +100,19 @@ def create_and_check_model(self, config, pixel_values, labels):
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+ def create_and_check_for_image_classification(self, config, pixel_values):
+ config.num_labels = self.type_sequence_label_size
+ model = FlaxViTForImageClassification(config=config)
+ result = model(pixel_values)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
+ # test greyscale images
+ config.num_channels = 1
+ model = FlaxViTForImageClassification(config)
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -123,7 +135,15 @@ def setUp(self) -> None:
def test_config(self):
self.config_tester.run_common_tests()
- # We neeed to override this test because ViT's forward signature is different than text models.
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ # We need to override this test because ViT's forward signature is different than text models.
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/vit/test_modeling_tf_vit.py b/tests/models/vit/test_modeling_tf_vit.py
index 096558091ac8..7f452886f150 100644
--- a/tests/models/vit/test_modeling_tf_vit.py
+++ b/tests/models/vit/test_modeling_tf_vit.py
@@ -133,6 +133,13 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels
result = model(pixel_values, interpolate_pos_encoding=True, training=False)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+ # test greyscale images
+ config.num_channels = 1
+ model = TFViTForImageClassification(config)
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
diff --git a/tests/models/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py
index a1379a9d31ec..5f856436f3c0 100644
--- a/tests/models/vit/test_modeling_vit.py
+++ b/tests/models/vit/test_modeling_vit.py
@@ -120,6 +120,25 @@ def create_and_check_model(self, config, pixel_values, labels):
result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
+ model = ViTForMaskedImageModeling(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
+ )
+
+ # test greyscale images
+ config.num_channels = 1
+ model = ViTForMaskedImageModeling(config)
+ model.to(torch_device)
+ model.eval()
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
+
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = ViTForImageClassification(config)
@@ -128,6 +147,16 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+ # test greyscale images
+ config.num_channels = 1
+ model = ViTForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -155,6 +184,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
+ fx_compatible = True
test_pruning = False
test_resize_embeddings = False
@@ -196,6 +226,10 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
+ def test_for_masked_image_modeling(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
+
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@@ -239,3 +273,30 @@ def test_inference_image_classification_head(self):
expected_slice = torch.tensor([-0.2744, 0.8215, -0.0836]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
+
+ @slow
+ def test_inference_interpolate_pos_encoding(self):
+ # ViT models have an `interpolate_pos_encoding` argument in their forward method,
+ # allowing to interpolate the pre-trained position embeddings in order to use
+ # the model on higher resolutions. The DINO model by Facebook AI leverages this
+ # to visualize self-attention on higher resolution images.
+ model = ViTModel.from_pretrained("facebook/dino-vits8").to(torch_device)
+
+ feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits8", size=480)
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="pt")
+ pixel_values = inputs.pixel_values.to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(pixel_values, interpolate_pos_encoding=True)
+
+ # verify the logits
+ expected_shape = torch.Size((1, 3601, 384))
+ self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
+ expected_slice = torch.tensor(
+ [[4.2340, 4.3906, -6.6692], [4.5463, 1.8928, -6.7257], [4.4429, 0.8496, -5.8585]]
+ ).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
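# Where the expected shape (1, 3601, 384) above comes from (illustrative check): facebook/dino-vits8 is
# a ViT-S/8, so a 480x480 input yields (480 // 8) ** 2 = 3600 patch tokens plus the [CLS] token, each
# with hidden size 384.
image_size, patch_size, hidden_size = 480, 8, 384
assert (1, (image_size // patch_size) ** 2 + 1, hidden_size) == (1, 3601, 384)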
diff --git a/tests/models/vit_mae/test_modeling_tf_vit_mae.py b/tests/models/vit_mae/test_modeling_tf_vit_mae.py
index dc73b20e7eac..906c79e766f4 100644
--- a/tests/models/vit_mae/test_modeling_tf_vit_mae.py
+++ b/tests/models/vit_mae/test_modeling_tf_vit_mae.py
@@ -38,7 +38,6 @@
import tensorflow as tf
from transformers import TFViTMAEForPreTraining, TFViTMAEModel
- from transformers.models.vit_mae.modeling_tf_vit_mae import to_2tuple
if is_vision_available():
@@ -67,6 +66,7 @@ def __init__(
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=3,
+ mask_ratio=0.6,
scope=None,
):
self.parent = parent
@@ -85,8 +85,14 @@ def __init__(
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
+ self.mask_ratio = mask_ratio
self.scope = scope
+ # in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded up
+ # (we add 1 for the [CLS] token)
+ num_patches = (image_size // patch_size) ** 2
+ self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1)))
+
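# Worked example of the formula above, with made-up image_size/patch_size (only mask_ratio=0.6 appears
# in this diff): image_size=32, patch_size=4 gives num_patches = (32 // 4) ** 2 = 64, so the encoder
# only sees ceil((1 - 0.6) * (64 + 1)) = 26 tokens (the unmasked patches plus the [CLS] token).
import math
assert int(math.ceil((1 - 0.6) * ((32 // 4) ** 2 + 1))) == 26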
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -107,34 +113,39 @@ def get_config(self):
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
+ decoder_hidden_size=self.hidden_size,
+ decoder_num_hidden_layers=self.num_hidden_layers,
+ decoder_num_attention_heads=self.num_attention_heads,
+ decoder_intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
+ mask_ratio=self.mask_ratio,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFViTMAEModel(config=config)
result = model(pixel_values, training=False)
- # expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
- # (we add 1 for the [CLS] token)
- image_size = to_2tuple(self.image_size)
- patch_size = to_2tuple(self.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- expected_seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
- self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, self.hidden_size))
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_pretraining(self, config, pixel_values, labels):
model = TFViTMAEForPreTraining(config)
result = model(pixel_values, training=False)
# expected sequence length = num_patches
- image_size = to_2tuple(self.image_size)
- patch_size = to_2tuple(self.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- expected_seq_len = num_patches
+ num_patches = (self.image_size // self.patch_size) ** 2
expected_num_channels = self.patch_size**2 * self.num_channels
- self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels))
+
+ # test greyscale images
+ config.num_channels = 1
+ model = TFViTMAEForPreTraining(config)
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values, training=False)
+ expected_num_channels = self.patch_size**2
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels))
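# Shape intuition for the pretraining logits checked above (sketch with hypothetical sizes): the decoder
# predicts one flattened patch per position, i.e. patch_size**2 * num_channels pixel values for each of
# the num_patches positions (there is no [CLS] row in the reconstruction).
batch_size, image_size, patch_size, num_channels = 2, 32, 4, 3
num_patches = (image_size // patch_size) ** 2        # 64
values_per_patch = patch_size**2 * num_channels      # 48 (16 for greyscale images)
assert (batch_size, num_patches, values_per_patch) == (2, 64, 48)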
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
@@ -166,7 +177,6 @@ def test_config(self):
@unittest.skip(reason="ViTMAE does not use inputs_embeds")
def test_inputs_embeds(self):
- # ViTMAE does not use inputs_embeds
pass
def test_model_common_attributes(self):
@@ -253,114 +263,6 @@ def prepare_numpy_arrays(inputs_dict):
output_for_kw_input = model(**inputs_np, noise=noise)
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
- def test_attention_outputs(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.return_dict = True
-
- # in ViTMAE, the seq_len equals (number of patches + 1) * (1 - mask_ratio), rounded above
- image_size = to_2tuple(self.model_tester.image_size)
- patch_size = to_2tuple(self.model_tester.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
- encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
- encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
- chunk_length = getattr(self.model_tester, "chunk_length", None)
- if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
- encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- config.return_dict = True
- model = model_class(config)
- outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
- attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
- attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- if chunk_length is not None:
- self.assertListEqual(
- list(attentions[0].shape[-4:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
- )
- else:
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
- )
- out_len = len(outputs)
-
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
-
- if hasattr(self.model_tester, "num_hidden_states_types"):
- added_hidden_states = self.model_tester.num_hidden_states_types
- elif self.is_encoder_decoder:
- added_hidden_states = 2
- else:
- added_hidden_states = 1
- self.assertEqual(out_len + added_hidden_states, len(outputs))
-
- self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
-
- self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
- if chunk_length is not None:
- self.assertListEqual(
- list(self_attentions[0].shape[-4:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
- )
- else:
- self.assertListEqual(
- list(self_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
- )
-
- def test_hidden_states_output(self):
- def check_hidden_states_output(inputs_dict, config, model_class):
- model = model_class(config)
-
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
-
- expected_num_layers = getattr(
- self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
- )
- self.assertEqual(len(hidden_states), expected_num_layers)
-
- # ViTMAE has a different seq_length
- image_size = to_2tuple(self.model_tester.image_size)
- patch_size = to_2tuple(self.model_tester.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- seq_length = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
-
- self.assertListEqual(
- list(hidden_states[0].shape[-2:]),
- [seq_length, self.model_tester.hidden_size],
- )
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- inputs_dict["output_hidden_states"] = True
- check_hidden_states_output(inputs_dict, config, model_class)
-
- # check that output_hidden_states also work using config
- del inputs_dict["output_hidden_states"]
- config.output_hidden_states = True
-
- check_hidden_states_output(inputs_dict, config, model_class)
-
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
# to generate masks during test
def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
@@ -379,7 +281,7 @@ def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
# overwrite from common since TFViTMAEForPretraining outputs loss along with
- # logits and mask indices. loss and mask indicies are not suitable for integration
+ # logits and mask indices. loss and mask indices are not suitable for integration
# with other keras modules.
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py
index 191984d82f55..5a48d253a385 100644
--- a/tests/models/vit_mae/test_modeling_vit_mae.py
+++ b/tests/models/vit_mae/test_modeling_vit_mae.py
@@ -35,7 +35,7 @@
from torch import nn
from transformers import ViTMAEForPreTraining, ViTMAEModel
- from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
+ from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
@@ -64,6 +64,7 @@ def __init__(
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=3,
+ mask_ratio=0.6,
scope=None,
):
self.parent = parent
@@ -82,8 +83,14 @@ def __init__(
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
+ self.mask_ratio = mask_ratio
self.scope = scope
+ # in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded up
+ # (we add 1 for the [CLS] token)
+ num_patches = (image_size // patch_size) ** 2
+ self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1)))
+
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -109,6 +116,7 @@ def get_config(self):
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
+ mask_ratio=self.mask_ratio,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -116,26 +124,26 @@ def create_and_check_model(self, config, pixel_values, labels):
model.to(torch_device)
model.eval()
result = model(pixel_values)
- # expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
- # (we add 1 for the [CLS] token)
- image_size = to_2tuple(self.image_size)
- patch_size = to_2tuple(self.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- expected_seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
- self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, self.hidden_size))
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_pretraining(self, config, pixel_values, labels):
model = ViTMAEForPreTraining(config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
- # expected sequence length = num_patches
- image_size = to_2tuple(self.image_size)
- patch_size = to_2tuple(self.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- expected_seq_len = num_patches
+ num_patches = (self.image_size // self.patch_size) ** 2
expected_num_channels = self.patch_size**2 * self.num_channels
- self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels))
+
+ # test greyscale images
+ config.num_channels = 1
+ model = ViTMAEForPreTraining(config)
+ model.to(torch_device)
+ model.eval()
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ expected_num_channels = self.patch_size**2
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
@@ -165,8 +173,8 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()
+ @unittest.skip(reason="ViTMAE does not use inputs_embeds")
def test_inputs_embeds(self):
- # ViTMAE does not use inputs_embeds
pass
def test_model_common_attributes(self):
@@ -198,126 +206,6 @@ def test_for_pretraining(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
- def test_attention_outputs(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.return_dict = True
-
- # in ViTMAE, the seq_len equals (number of patches + 1) * (1 - mask_ratio), rounded above
- image_size = to_2tuple(self.model_tester.image_size)
- patch_size = to_2tuple(self.model_tester.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
- encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
- encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
- chunk_length = getattr(self.model_tester, "chunk_length", None)
- if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
- encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- config.return_dict = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- if chunk_length is not None:
- self.assertListEqual(
- list(attentions[0].shape[-4:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
- )
- else:
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
- )
- out_len = len(outputs)
-
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- if hasattr(self.model_tester, "num_hidden_states_types"):
- added_hidden_states = self.model_tester.num_hidden_states_types
- elif self.is_encoder_decoder:
- added_hidden_states = 2
- else:
- added_hidden_states = 1
- self.assertEqual(out_len + added_hidden_states, len(outputs))
-
- self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
-
- self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
- if chunk_length is not None:
- self.assertListEqual(
- list(self_attentions[0].shape[-4:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
- )
- else:
- self.assertListEqual(
- list(self_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
- )
-
- def test_hidden_states_output(self):
- def check_hidden_states_output(inputs_dict, config, model_class):
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
-
- expected_num_layers = getattr(
- self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
- )
- self.assertEqual(len(hidden_states), expected_num_layers)
-
- # ViTMAE has a different seq_length
- image_size = to_2tuple(self.model_tester.image_size)
- patch_size = to_2tuple(self.model_tester.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- seq_length = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
-
- self.assertListEqual(
- list(hidden_states[0].shape[-2:]),
- [seq_length, self.model_tester.hidden_size],
- )
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- inputs_dict["output_hidden_states"] = True
- check_hidden_states_output(inputs_dict, config, model_class)
-
- # check that output_hidden_states also work using config
- del inputs_dict["output_hidden_states"]
- config.output_hidden_states = True
-
- check_hidden_states_output(inputs_dict, config, model_class)
-
# overwrite from common since ViTMAEForPretraining has random masking, we need to fix the noise
# to generate masks during test
def check_pt_tf_models(self, tf_model, pt_model, pt_inputs_dict):
diff --git a/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py
index a228ebfa1947..b74e271c02d6 100644
--- a/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py
+++ b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py
@@ -463,7 +463,8 @@ def test_inference_ctc_robust_batched(self):
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
- "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
+ "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around"
+ " him with the thousands of spectators were trivialities not worth thinking about",
"his instant panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
index 3187303982e1..323f44ba99fb 100644
--- a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
+++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
@@ -548,7 +548,8 @@ def test_inference_ctc_robust_batched(self):
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
- "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
+ "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around"
+ " him with the thousands of spectators were trivialities not worth thinking about",
"his instant panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py
index 98aebdd72818..21f77b19a553 100644
--- a/tests/models/wav2vec2/test_modeling_wav2vec2.py
+++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py
@@ -1179,7 +1179,8 @@ def test_inference_ctc_robust_batched(self):
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
- "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
+ "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around"
+ " him with the thousands of spectators were trivialities not worth thinking about",
"his instant panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
@@ -1461,8 +1462,11 @@ def test_phoneme_recognition(self):
EXPECTED_TRANSCRIPTIONS = [
"É m Ʀ n s É d t É Ć° É j uĖ n ÉŖ v É s s É aÉŖ É É” z ÉŖ s t",
- "s w É t k Ź v É d b ɹ iĖ É n z b ÉĖ d i t ɹ ÉŖ k l ÉŖ Å ÉŖ n t É Ć° É t aÉŖ t l oÉŖ n k l ÉĖ Īø ư Ʀ w Ź z ư ÉŖ oŹ n l i É” ÉĖɹ m É n t h iĖ w ÉĖɹ",
- "ư É k aÉŖ t É n h ÉŖ z tŹ É s t s t ÉŖ l d ɹ ÉŖ p ÉŖ Å b l Ź d ư ÉŖ eÉŖ k Ź v h ÉŖ z oŹ v É s t ɹ eÉŖ n d aÉŖ z iĖ v É n ư É s ÉĖɹ ɹ ÉŖ Å É É¹ iĖ n É É É¹ aŹ n d h ÉŖ m w ÉŖ ư É Īø aŹ z É n d z Ź v s p É k t eÉŖ ɾ É z w ÉĖ t ɹ ÉŖ v ÉŖ Ʀ l įµ» ɾ i z n ÉĖ t w ÉĖ Īø Īø ÉŖ Å k ÉŖ Å É b aŹ t",
+ "s w É t k Ź v É d b ɹ iĖ É n z b ÉĖ d i t ɹ ÉŖ k l ÉŖ Å ÉŖ n t É Ć° É t aÉŖ t l oÉŖ n k l ÉĖ Īø ư Ʀ w Ź z ư ÉŖ oŹ"
+ " n l i É” ÉĖɹ m É n t h iĖ w ÉĖɹ",
+ "ư É k aÉŖ t É n h ÉŖ z tŹ É s t s t ÉŖ l d ɹ ÉŖ p ÉŖ Å b l Ź d ư ÉŖ eÉŖ k Ź v h ÉŖ z oŹ v É s t ɹ eÉŖ n d aÉŖ z iĖ"
+ " v É n ư É s ÉĖɹ ɹ ÉŖ Å É É¹ iĖ n É É É¹ aŹ n d h ÉŖ m w ÉŖ ư É Īø aŹ z É n d z Ź v s p É k t eÉŖ ɾ É z w ÉĖ t ɹ"
+ " ÉŖ v ÉŖ Ʀ l įµ» ɾ i z n ÉĖ t w ÉĖ Īø Īø ÉŖ Å k ÉŖ Å É b aŹ t",
"h ÉŖ z ÉŖ n s t É n t v p Ʀ n ÉŖ k w Ź z f ÉĖ l oŹ d b aÉŖ É s m ÉĖ l Ź ÉĖɹ p b l oŹ h aÉŖ É n h ÉŖ z tŹ É s t",
]
# should correspond to =>:
diff --git a/tests/models/wav2vec2/test_processor_wav2vec2.py b/tests/models/wav2vec2/test_processor_wav2vec2.py
index 8b7188f8ebc0..5f1c259061c4 100644
--- a/tests/models/wav2vec2/test_processor_wav2vec2.py
+++ b/tests/models/wav2vec2/test_processor_wav2vec2.py
@@ -118,8 +118,7 @@ def test_tokenizer(self):
input_str = "This is a test string"
- with processor.as_target_processor():
- encoded_processor = processor(input_str)
+ encoded_processor = processor(text=input_str)
encoded_tok = tokenizer(input_str)
diff --git a/tests/models/wav2vec2_conformer/__init__.py b/tests/models/wav2vec2_conformer/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py
new file mode 100644
index 000000000000..cb2719a591b6
--- /dev/null
+++ b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py
@@ -0,0 +1,939 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch Wav2Vec2-Conformer model. """
+
+import math
+import unittest
+
+import numpy as np
+from datasets import load_dataset
+
+from transformers import Wav2Vec2ConformerConfig, is_torch_available
+from transformers.testing_utils import is_pt_flax_cross_test, require_torch, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ _config_zero_init,
+ floats_tensor,
+ ids_tensor,
+ random_attention_mask,
+)
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ Wav2Vec2ConformerForAudioFrameClassification,
+ Wav2Vec2ConformerForCTC,
+ Wav2Vec2ConformerForPreTraining,
+ Wav2Vec2ConformerForSequenceClassification,
+ Wav2Vec2ConformerForXVector,
+ Wav2Vec2ConformerModel,
+ Wav2Vec2FeatureExtractor,
+ Wav2Vec2Processor,
+ )
+ from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
+ Wav2Vec2ConformerGumbelVectorQuantizer,
+ _compute_mask_indices,
+ _sample_negative_indices,
+ )
+
+
+class Wav2Vec2ConformerModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=1024, # speech is longer
+ is_training=False,
+ hidden_size=16,
+ feat_extract_norm="group",
+ feat_extract_dropout=0.0,
+ feat_extract_activation="gelu",
+ conv_dim=(32, 32, 32),
+ conv_stride=(4, 4, 4),
+ conv_kernel=(8, 8, 8),
+ conv_bias=False,
+ num_conv_pos_embeddings=16,
+ num_conv_pos_embedding_groups=2,
+ num_hidden_layers=4,
+ num_attention_heads=2,
+ hidden_dropout_prob=0.1,
+ intermediate_size=20,
+ layer_norm_eps=1e-5,
+ hidden_act="gelu",
+ initializer_range=0.02,
+ mask_time_prob=0.5,
+ mask_time_length=2,
+ vocab_size=32,
+ do_stable_layer_norm=False,
+ num_adapter_layers=1,
+ adapter_stride=2,
+ tdnn_dim=(32, 32),
+ tdnn_kernel=(5, 3),
+ tdnn_dilation=(1, 2),
+ xvector_output_dim=32,
+ position_embeddings_type="relative",
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.hidden_size = hidden_size
+ self.feat_extract_norm = feat_extract_norm
+ self.feat_extract_dropout = feat_extract_dropout
+ self.feat_extract_activation = feat_extract_activation
+ self.conv_dim = conv_dim
+ self.conv_stride = conv_stride
+ self.conv_kernel = conv_kernel
+ self.conv_bias = conv_bias
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.intermediate_size = intermediate_size
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.vocab_size = vocab_size
+ self.do_stable_layer_norm = do_stable_layer_norm
+ self.num_adapter_layers = num_adapter_layers
+ self.adapter_stride = adapter_stride
+ self.mask_time_prob = mask_time_prob
+ self.mask_time_length = mask_time_length
+ self.scope = scope
+ self.tdnn_dim = tdnn_dim
+ self.tdnn_kernel = tdnn_kernel
+ self.tdnn_dilation = tdnn_dilation
+ self.xvector_output_dim = xvector_output_dim
+ self.position_embeddings_type = position_embeddings_type
+
+ output_seq_length = self.seq_length
+ for kernel, stride in zip(self.conv_kernel, self.conv_stride):
+ output_seq_length = (output_seq_length - (kernel - 1)) / stride
+ self.output_seq_length = int(math.ceil(output_seq_length))
+ self.encoder_seq_length = self.output_seq_length
+
+ self.adapter_output_seq_length = (self.output_seq_length - 1) // adapter_stride + 1
+
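# Worked example of the two length computations above, using this tester's defaults
# (seq_length=1024, conv_kernel=(8, 8, 8), conv_stride=(4, 4, 4), adapter_stride=2):
import math

length = 1024
for kernel, stride in zip((8, 8, 8), (4, 4, 4)):
    length = (length - (kernel - 1)) / stride   # 254.25 -> 61.8125 -> 13.703125
assert int(math.ceil(length)) == 14             # output_seq_length
assert (14 - 1) // 2 + 1 == 7                   # adapter_output_seq_length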
+ def prepare_config_and_inputs(self, position_embeddings_type="relative"):
+ input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ attention_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ config = self.get_config(position_embeddings_type=position_embeddings_type)
+
+ return config, input_values, attention_mask
+
+ def get_config(self, position_embeddings_type="relative"):
+ return Wav2Vec2ConformerConfig(
+ hidden_size=self.hidden_size,
+ feat_extract_norm=self.feat_extract_norm,
+ feat_extract_dropout=self.feat_extract_dropout,
+ feat_extract_activation=self.feat_extract_activation,
+ conv_dim=self.conv_dim,
+ conv_stride=self.conv_stride,
+ conv_kernel=self.conv_kernel,
+ conv_bias=self.conv_bias,
+ mask_time_prob=self.mask_time_prob,
+ mask_time_length=self.mask_time_length,
+ num_conv_pos_embeddings=self.num_conv_pos_embeddings,
+ num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ intermediate_size=self.intermediate_size,
+ layer_norm_eps=self.layer_norm_eps,
+ do_stable_layer_norm=self.do_stable_layer_norm,
+ hidden_act=self.hidden_act,
+ initializer_range=self.initializer_range,
+ vocab_size=self.vocab_size,
+ num_adapter_layers=self.num_adapter_layers,
+ adapter_stride=self.adapter_stride,
+ tdnn_dim=self.tdnn_dim,
+ tdnn_kernel=self.tdnn_kernel,
+ tdnn_dilation=self.tdnn_dilation,
+ xvector_output_dim=self.xvector_output_dim,
+ position_embeddings_type=position_embeddings_type,
+ )
+
+ def create_and_check_model(self, config, input_values, attention_mask):
+ model = Wav2Vec2ConformerModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_values, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
+ )
+
+ def create_and_check_model_with_adapter(self, config, input_values, attention_mask):
+ config.add_adapter = True
+ model = Wav2Vec2ConformerModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_values, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.adapter_output_seq_length, self.hidden_size)
+ )
+
+ def create_and_check_model_with_adapter_for_ctc(self, config, input_values, attention_mask):
+ config.add_adapter = True
+ config.output_hidden_size = 2 * config.hidden_size
+ model = Wav2Vec2ConformerForCTC(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_values, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.adapter_output_seq_length, self.vocab_size)
+ )
+
+ def create_and_check_model_with_adapter_proj_dim(self, config, input_values, attention_mask):
+ config.add_adapter = True
+ config.output_hidden_size = 8
+ model = Wav2Vec2ConformerModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_values, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
+ )
+
+ def create_and_check_batch_inference(self, config, input_values, *args):
+ # test does not pass for models making use of `group_norm`
+ # check: https://github.com/pytorch/fairseq/issues/3227
+ model = Wav2Vec2ConformerModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ input_values = input_values[:3]
+ attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0.0
+
+ batch_outputs = model(input_values, attention_mask=attention_mask).last_hidden_state
+
+ for i in range(input_values.shape[0]):
+ input_slice = input_values[i : i + 1, : input_lengths[i]]
+ output = model(input_slice).last_hidden_state
+
+ batch_output = batch_outputs[i : i + 1, : output.shape[1]]
+ self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
+
+ def check_ctc_loss(self, config, input_values, *args):
+ model = Wav2Vec2ConformerForCTC(config=config)
+ model.to(torch_device)
+
+ # make sure that dropout is disabled
+ model.eval()
+
+ input_values = input_values[:3]
+ attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
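+ # keep labels strictly shorter than the shortest feature sequence so every example has a valid CTC alignment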
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0
+
+ model.config.ctc_loss_reduction = "sum"
+ sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
+
+ model.config.ctc_loss_reduction = "mean"
+ mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
+
+ self.parent.assertTrue(isinstance(sum_loss, float))
+ self.parent.assertTrue(isinstance(mean_loss, float))
+
+ def check_seq_classifier_loss(self, config, input_values, *args):
+ model = Wav2Vec2ConformerForSequenceClassification(config=config)
+ model.to(torch_device)
+
+ # make sure that dropout is disabled
+ model.eval()
+
+ input_values = input_values[:3]
+ attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0
+
+ masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
+ unmasked_loss = model(input_values, labels=labels).loss.item()
+
+ self.parent.assertTrue(isinstance(masked_loss, float))
+ self.parent.assertTrue(isinstance(unmasked_loss, float))
+ self.parent.assertTrue(masked_loss != unmasked_loss)
+
+ def check_ctc_training(self, config, input_values, *args):
+ config.ctc_zero_infinity = True
+ model = Wav2Vec2ConformerForCTC(config=config)
+ model.to(torch_device)
+ model.train()
+
+ # freeze feature encoder
+ model.freeze_feature_encoder()
+
+ input_values = input_values[:3]
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size)
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+
+ if max_length_labels[i] < labels.shape[-1]:
+ # it's important to make sure that target lengths are at least
+ # one shorter than logit lengths to prevent -inf loss values
+ labels[i, max_length_labels[i] - 1 :] = -100
+
+ loss = model(input_values, labels=labels).loss
+ self.parent.assertFalse(torch.isinf(loss).item())
+
+ loss.backward()
+
+ def check_seq_classifier_training(self, config, input_values, *args):
+ config.ctc_zero_infinity = True
+ model = Wav2Vec2ConformerForSequenceClassification(config=config)
+ model.to(torch_device)
+ model.train()
+
+ # freeze everything but the classification head
+ model.freeze_base_model()
+
+ input_values = input_values[:3]
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+
+ loss = model(input_values, labels=labels).loss
+ self.parent.assertFalse(torch.isinf(loss).item())
+
+ loss.backward()
+
+ def check_xvector_training(self, config, input_values, *args):
+ config.ctc_zero_infinity = True
+ model = Wav2Vec2ConformerForXVector(config=config)
+ model.to(torch_device)
+ model.train()
+
+ # freeze everything but the classification head
+ model.freeze_base_model()
+
+ input_values = input_values[:3]
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+
+ loss = model(input_values, labels=labels).loss
+ self.parent.assertFalse(torch.isinf(loss).item())
+
+ loss.backward()
+
+ def check_labels_out_of_vocab(self, config, input_values, *args):
+ model = Wav2Vec2ConformerForCTC(config)
+ model.to(torch_device)
+ model.train()
+
+ input_values = input_values[:3]
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100)
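+ # labels deliberately exceed the vocabulary size, which should raise a ValueError below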
+
+ with self.parent.assertRaises(ValueError):
+ model(input_values, labels=labels)
+
+ def prepare_config_and_inputs_for_common(self):
+ config, input_values, attention_mask = self.prepare_config_and_inputs()
+ inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class Wav2Vec2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (
+ (
+ Wav2Vec2ConformerForCTC,
+ Wav2Vec2ConformerModel,
+ Wav2Vec2ConformerForSequenceClassification,
+ Wav2Vec2ConformerForPreTraining,
+ Wav2Vec2ConformerForAudioFrameClassification,
+ Wav2Vec2ConformerForXVector,
+ )
+ if is_torch_available()
+ else ()
+ )
+ test_pruning = False
+ test_headmasking = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = Wav2Vec2ConformerModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=Wav2Vec2ConformerConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_with_relative(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative")
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_with_rotary(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary")
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_with_no_rel_pos(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type=None)
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_with_adapter(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model_with_adapter(*config_and_inputs)
+
+ def test_model_with_adapter_for_ctc(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model_with_adapter_for_ctc(*config_and_inputs)
+
+ def test_model_with_adapter_proj_dim(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
+
+ def test_ctc_loss_inference(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_loss(*config_and_inputs)
+
+ def test_seq_classifier_loss_inference(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_seq_classifier_loss(*config_and_inputs)
+
+ def test_ctc_train(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_training(*config_and_inputs)
+
+ def test_seq_classifier_train(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_seq_classifier_training(*config_and_inputs)
+
+ def test_xvector_train(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_xvector_training(*config_and_inputs)
+
+ def test_labels_out_of_vocab(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
+
+ # Wav2Vec2Conformer has no inputs_embeds
+ def test_inputs_embeds(self):
+ pass
+
+ # `input_ids` is renamed to `input_values`
+ def test_forward_signature(self):
+ pass
+
+ # Wav2Vec2Conformer cannot resize token embeddings
+ # since it has no token embeddings
+ def test_resize_tokens_embeddings(self):
+ pass
+
+ # Wav2Vec2Conformer has no inputs_embeds
+ # and thus the `get_input_embeddings` fn
+ # is not implemented
+ def test_model_common_attributes(self):
+ pass
+
+ @is_pt_flax_cross_test
+ # non-robust architecture does not exist in Flax
+ def test_equivalence_flax_to_pt(self):
+ pass
+
+ @is_pt_flax_cross_test
+ # non-robust architecture does not exist in Flax
+ def test_equivalence_pt_to_flax(self):
+ pass
+
+ def test_retain_grad_hidden_states_attentions(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = True
+
+ # no need to test all models as different heads yield the same functionality
+ model_class = self.all_model_classes[0]
+ model = model_class(config)
+ model.to(torch_device)
+
+ # set layer drop to 0
+ model.config.layerdrop = 0.0
+
+ input_values = inputs_dict["input_values"]
+
+ input_lengths = torch.tensor(
+ [input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device
+ )
+ output_lengths = model._get_feat_extract_output_lengths(input_lengths)
+
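+ # CTC-style labels have to be shorter than the encoder output length (the first model class is the CTC head)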
+ labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
+ inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
+ inputs_dict["labels"] = labels
+
+ outputs = model(**inputs_dict)
+
+ output = outputs[0]
+
+ # Encoder-/Decoder-only models
+ hidden_states = outputs.hidden_states[0]
+ attentions = outputs.attentions[0]
+
+ hidden_states.retain_grad()
+ attentions.retain_grad()
+
+ output.flatten()[0].backward(retain_graph=True)
+
+ self.assertIsNotNone(hidden_states.grad)
+ self.assertIsNotNone(attentions.grad)
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ uniform_init_parms = [
+ "conv.weight",
+ "masked_spec_embed",
+ "codevectors",
+ "quantizer.weight_proj.weight",
+ "project_hid.weight",
+ "project_hid.bias",
+ "project_q.weight",
+ "project_q.bias",
+ "pos_bias_v",
+ "pos_bias_u",
+ "pointwise_conv1",
+ "pointwise_conv2",
+ "feature_projection.projection.weight",
+ "feature_projection.projection.bias",
+ "objective.weight",
+ ]
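+ # with a zero-initialized config, uniformly initialized params should have a mean within [-1, 1],
+ # while every other param is expected to be exactly 0 (most weights) or 1 (e.g. layer norm weights)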
+ if param.requires_grad:
+ if any([x in name for x in uniform_init_parms]):
+ self.assertTrue(
+ -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ # overwrite from test_modeling_common
+ def _mock_init_weights(self, module):
+ if hasattr(module, "weight") and module.weight is not None:
+ module.weight.data.fill_(3)
+ if hasattr(module, "weight_g") and module.weight_g is not None:
+ module.weight_g.data.fill_(3)
+ if hasattr(module, "weight_v") and module.weight_v is not None:
+ module.weight_v.data.fill_(3)
+ if hasattr(module, "bias") and module.bias is not None:
+ module.bias.data.fill_(3)
+ if hasattr(module, "pos_bias_u") and module.pos_bias_u is not None:
+ module.pos_bias_u.data.fill_(3)
+ if hasattr(module, "pos_bias_v") and module.pos_bias_v is not None:
+ module.pos_bias_v.data.fill_(3)
+ if hasattr(module, "codevectors") and module.codevectors is not None:
+ module.codevectors.data.fill_(3)
+ if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
+ module.masked_spec_embed.data.fill_(3)
+
+ def test_mask_feature_prob_ctc(self):
+ model = Wav2Vec2ConformerForCTC.from_pretrained(
+ "hf-internal-testing/tiny-random-wav2vec2-conformer", mask_feature_prob=0.2, mask_feature_length=2
+ )
+ model.to(torch_device).train()
+ processor = Wav2Vec2Processor.from_pretrained(
+ "hf-internal-testing/tiny-random-wav2vec2-conformer", return_attention_mask=True
+ )
+
+ batch_duration_in_seconds = [1, 3, 2, 6]
+ input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
+
+ batch = processor(
+ input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
+ )
+
+ logits = model(
+ input_values=batch["input_values"].to(torch_device),
+ attention_mask=batch["attention_mask"].to(torch_device),
+ ).logits
+
+ self.assertEqual(logits.shape, (4, 1498, 32))
+
+ def test_mask_time_prob_ctc(self):
+ model = Wav2Vec2ConformerForCTC.from_pretrained(
+ "hf-internal-testing/tiny-random-wav2vec2-conformer", mask_time_prob=0.2, mask_time_length=2
+ )
+ model.to(torch_device).train()
+ processor = Wav2Vec2Processor.from_pretrained(
+ "hf-internal-testing/tiny-random-wav2vec2-conformer", return_attention_mask=True
+ )
+
+ batch_duration_in_seconds = [1, 3, 2, 6]
+ input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
+
+ batch = processor(
+ input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
+ )
+
+ logits = model(
+ input_values=batch["input_values"].to(torch_device),
+ attention_mask=batch["attention_mask"].to(torch_device),
+ ).logits
+
+ self.assertEqual(logits.shape, (4, 1498, 32))
+
+ @unittest.skip(reason="Feed forward chunking is not implemented")
+ def test_feed_forward_chunking(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ model = Wav2Vec2ConformerModel.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+ self.assertIsNotNone(model)
+
+
+@require_torch
+class Wav2Vec2ConformerUtilsTest(unittest.TestCase):
+ def test_compute_mask_indices(self):
+ batch_size = 4
+ sequence_length = 60
+ mask_prob = 0.5
+ mask_length = 1
+
+ mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+ mask = torch.from_numpy(mask).to(torch_device)
+
+ self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
+
+ def test_compute_mask_indices_low_prob(self):
+ # with these settings num_masked_spans=0.5, which means probabilistic rounding
+ # ensures that in roughly 5 out of 10 method calls num_masked_spans=0, and in
+ # the other 5 out of 10 calls num_masked_spans=1
+ n_trials = 100
+ batch_size = 4
+ sequence_length = 100
+ mask_prob = 0.05
+ mask_length = 10
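+ # expected spans per example: mask_prob * sequence_length / mask_length = 0.05 * 100 / 10 = 0.5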
+
+ count_dimensions_masked = 0
+ count_dimensions_not_masked = 0
+
+ for _ in range(n_trials):
+ mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+ mask = torch.from_numpy(mask).to(torch_device)
+
+ num_masks = torch.sum(mask).item()
+
+ if num_masks > 0:
+ count_dimensions_masked += 1
+ else:
+ count_dimensions_not_masked += 1
+
+ # as we test for at least 10 trials with masking and at least
+ # 10 trials without masking, this test could fail with probability:
+ # P(100 coin flips, at most 9 heads) = 1.66e-18
+ self.assertGreater(count_dimensions_masked, int(n_trials * 0.1))
+ self.assertGreater(count_dimensions_not_masked, int(n_trials * 0.1))
+
+ def test_compute_mask_indices_overlap(self):
+ batch_size = 4
+ sequence_length = 80
+ mask_prob = 0.5
+ mask_length = 4
+
+ mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+ mask = torch.from_numpy(mask).to(torch_device)
+
+ # because spans can overlap, the mask does not have to add up exactly to `mask_prob * sequence_length`, but has to be smaller or equal
+ for batch_sum in mask.sum(axis=-1):
+ self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
+
+ def test_compute_mask_indices_attn_mask_overlap(self):
+ batch_size = 4
+ sequence_length = 80
+ mask_prob = 0.5
+ mask_length = 4
+
+ attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
+ attention_mask[:2, sequence_length // 2 :] = 0
+
+ mask = _compute_mask_indices(
+ (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
+ )
+ mask = torch.from_numpy(mask).to(torch_device)
+
+ for batch_sum in mask.sum(axis=-1):
+ self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
+
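+ # no span should be placed on padded positions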
+ self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
+
+ def test_compute_mask_indices_short_audio(self):
+ batch_size = 4
+ sequence_length = 100
+ mask_prob = 0.05
+ mask_length = 10
+
+ attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
+ # force one example to be heavily padded
+ attention_mask[0, 5:] = 0
+
+ mask = _compute_mask_indices(
+ (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask, min_masks=2
+ )
+
+ # the short example has fewer valid frames than `mask_length`, so its non-padded part must receive no mask at all
+ self.assertFalse(mask[0][attention_mask[0].to(torch.bool).cpu()].any())
+
+ def test_compute_perplexity(self):
+ probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100
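+ # probs: 2 sequences x 5 frames x 10 codevector probabilities, evenly spread over [0, 1)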
+
+ ppl = Wav2Vec2ConformerGumbelVectorQuantizer._compute_perplexity(probs)
+ self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)
+
+ # mask half of the input
+ mask = torch.ones((2,), device=torch_device, dtype=torch.bool)
+ mask[0] = 0
+
+ ppl = Wav2Vec2ConformerGumbelVectorQuantizer._compute_perplexity(probs, mask)
+ self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)
+
+ def test_sample_negatives(self):
+ batch_size = 2
+ sequence_length = 10
+ hidden_size = 4
+ num_negatives = 3
+
+ features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
+ sequence_length, hidden_size
+ ) # each feature vector is filled with a single repeated value, namely its own index
+ features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
+
+ # sample negative indices
+ sampled_negative_indices = _sample_negative_indices((batch_size, sequence_length), num_negatives, None)
+ sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
+ negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
+ negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
+ self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
+
+ # make sure no negatively sampled vector is actually a positive one
+ for negative in negatives:
+ self.assertTrue(((negative - features) == 0).sum() == 0.0)
+
+ # make sure whole feature vectors were sampled, not individual values => `unique()` over the `hidden_size` dim must yield a single value
+ self.assertEqual(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
+
+ def test_sample_negatives_with_mask(self):
+ batch_size = 2
+ sequence_length = 10
+ hidden_size = 4
+ num_negatives = 3
+
+ # second half of last input tensor is padded
+ mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
+ mask[-1, sequence_length // 2 :] = 0
+
+ features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
+ sequence_length, hidden_size
+ ) # each feature vector is filled with a single repeated value, namely its own index
+ features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
+
+ # replace masked feature vectors with -100 to test that those are not sampled
+ features = torch.where(mask[:, :, None].expand(features.shape).bool(), features, -100)
+
+ # sample negative indices
+ sampled_negative_indices = _sample_negative_indices(
+ (batch_size, sequence_length), num_negatives, mask.cpu().numpy()
+ )
+ sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
+ negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
+ negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
+
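+ # none of the masked (-100) feature vectors should have been sampled as negatives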
+ self.assertTrue((negatives >= 0).all().item())
+
+ self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
+
+ # make sure no negatively sampled vector is actually a positive one
+ for negative in negatives:
+ self.assertTrue(((negative - features) == 0).sum() == 0.0)
+
+ # make sure whole feature vectors were sampled, not individual values => `unique()` over the `hidden_size` dim must yield a single value
+ self.assertEqual(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
+
+
+@require_torch
+@slow
+class Wav2Vec2ConformerModelIntegrationTest(unittest.TestCase):
+ def _load_datasamples(self, num_samples):
+ ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ # automatic decoding with librispeech
+ speech_samples = ds.sort("id").filter(lambda x: x["id"] in [f"1272-141231-000{i}" for i in range(num_samples)])
+ speech_samples = speech_samples[:num_samples]["audio"]
+
+ return [x["array"] for x in speech_samples]
+
+ def test_inference_ctc_normal_batched_rel_pos(self):
+ model = Wav2Vec2ConformerForCTC.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large-960h-ft")
+ model.to(torch_device)
+ processor = Wav2Vec2Processor.from_pretrained(
+ "facebook/wav2vec2-conformer-rel-pos-large-960h-ft", do_lower_case=True
+ )
+
+ input_speech = self._load_datasamples(2)
+
+ inputs = processor(input_speech, return_tensors="pt", padding=True)
+
+ input_values = inputs.input_values.to(torch_device)
+
+ with torch.no_grad():
+ logits = model(input_values).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_trans = processor.batch_decode(predicted_ids)
+
+ EXPECTED_TRANSCRIPTIONS = [
+ "a man said to the universe sir i exist",
+ "sweat covered brion's body trickling into the tight loincloth that was the only garment he wore",
+ ]
+ self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+
+ def test_inference_ctc_normal_batched_rope(self):
+ model = Wav2Vec2ConformerForCTC.from_pretrained("facebook/wav2vec2-conformer-rope-large-960h-ft")
+ model.to(torch_device)
+ processor = Wav2Vec2Processor.from_pretrained(
+ "facebook/wav2vec2-conformer-rope-large-960h-ft", do_lower_case=True
+ )
+
+ input_speech = self._load_datasamples(2)
+
+ inputs = processor(input_speech, return_tensors="pt", padding=True)
+
+ input_values = inputs.input_values.to(torch_device)
+
+ with torch.no_grad():
+ logits = model(input_values).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_trans = processor.batch_decode(predicted_ids)
+
+ EXPECTED_TRANSCRIPTIONS = [
+ "a man said to the universe sir i exist",
+ "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
+ ]
+ self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+
+ def test_inference_pretrained(self):
+ model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+ model.to(torch_device)
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+ "facebook/wav2vec2-conformer-rel-pos-large", return_attention_mask=True
+ )
+ input_speech = self._load_datasamples(2)
+
+ inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
+
+ batch_size = inputs_dict["input_values"].shape[0]
+ feature_seq_length = int(model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]))
+
+ features_shape = (batch_size, feature_seq_length)
+
+ torch.manual_seed(0)
+ mask_time_indices = _compute_mask_indices(
+ features_shape,
+ model.config.mask_time_prob,
+ model.config.mask_time_length,
+ min_masks=2,
+ )
+ mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
+
+ with torch.no_grad():
+ outputs = model(
+ inputs_dict.input_values.to(torch_device),
+ attention_mask=inputs_dict.attention_mask.to(torch_device),
+ mask_time_indices=mask_time_indices,
+ )
+
+ # compute cosine similarity
+ cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
+
+ # retrieve cosine sim of masked features
+ cosine_sim_masked = cosine_sim[mask_time_indices]
+
+ # ... now compare to randomly initialized model
+
+ config = Wav2Vec2ConformerConfig.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+ model_rand = Wav2Vec2ConformerForPreTraining(config).to(torch_device).eval()
+
+ with torch.no_grad():
+ outputs_rand = model_rand(
+ inputs_dict.input_values.to(torch_device),
+ attention_mask=inputs_dict.attention_mask.to(torch_device),
+ mask_time_indices=mask_time_indices,
+ )
+
+ # compute cosine similarity
+ cosine_sim_rand = torch.cosine_similarity(
+ outputs_rand.projected_states, outputs_rand.projected_quantized_states, dim=-1
+ )
+
+ # retrieve cosine sim of masked features
+ cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]
+
+ # a pretrained wav2vec2_conformer model has learned to predict the quantized latent states
+ # => the cosine similarity between quantized states and predicted states > 0.5
+ # a random wav2vec2_conformer model has not learned to predict the quantized latent states
+ # => the cosine similarity between quantized states and predicted states is very likely < 0.1
+ self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
diff --git a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
index f5b3eea926d8..d66a5923868d 100644
--- a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
+++ b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
@@ -164,8 +164,7 @@ def test_tokenizer(self):
input_str = "This is a test string"
- with processor.as_target_processor():
- encoded_processor = processor(input_str)
+ encoded_processor = processor(text=input_str)
encoded_tok = tokenizer(input_str)
diff --git a/tests/models/xglm/test_modeling_xglm.py b/tests/models/xglm/test_modeling_xglm.py
index 37301a79eda1..f4da4994266d 100644
--- a/tests/models/xglm/test_modeling_xglm.py
+++ b/tests/models/xglm/test_modeling_xglm.py
@@ -13,17 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import datetime
import math
+import os
+import pickle
+import tempfile
import unittest
from transformers import XGLMConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.utils import is_torch_fx_available
from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ _config_zero_init,
+ floats_tensor,
+ ids_tensor,
+ random_attention_mask,
+)
if is_torch_available():
@@ -31,6 +40,9 @@
from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer
+if is_torch_fx_available():
+ from transformers.utils.fx import symbolic_trace
+
class XGLMModelTester:
def __init__(
@@ -299,6 +311,7 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (XGLMModel, XGLMForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (XGLMForCausalLM,) if is_torch_available() else ()
+ fx_compatible = True
test_missing_keys = False
test_pruning = False
@@ -337,6 +350,112 @@ def test_xglm_weight_initialization(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs)
+ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+ if not is_torch_fx_available() or not self.fx_compatible:
+ return
+
+ configs_no_init = _config_zero_init(config) # To be sure we have no Nan
+ configs_no_init.return_dict = False
+
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
+
+ try:
+ if model.config.is_encoder_decoder:
+ model.config.use_cache = False # FSMT still requires this hack -> FSMT should probably be refactored similar to BART afterward
+ labels = inputs.get("labels", None)
+ input_names = [
+ "input_ids",
+ "attention_mask",
+ "decoder_input_ids",
+ "decoder_attention_mask",
+ "input_features",
+ ]
+ if labels is not None:
+ input_names.append("labels")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+ else:
+ input_names = [
+ "input_ids",
+ "attention_mask",
+ "token_type_ids",
+ "pixel_values",
+ "bbox",
+ "input_features",
+ ]
+
+ labels = inputs.get("labels", None)
+ start_positions = inputs.get("start_positions", None)
+ end_positions = inputs.get("end_positions", None)
+ if labels is not None:
+ input_names.append("labels")
+ if start_positions is not None:
+ input_names.append("start_positions")
+ if end_positions is not None:
+ input_names.append("end_positions")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+
+ except RuntimeError as e:
+ self.fail(f"Couldn't trace module: {e}")
+
+ def flatten_output(output):
+ flatten = []
+ for x in output:
+ if isinstance(x, (tuple, list)):
+ flatten += flatten_output(x)
+ elif not isinstance(x, torch.Tensor):
+ continue
+ else:
+ flatten.append(x)
+ return flatten
+
+ model_output = flatten_output(model_output)
+ traced_output = flatten_output(traced_output)
+ num_outputs = len(model_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], traced_output[i]),
+ f"traced {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
+ # Test that the model can be serialized and restored properly
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+ try:
+ with open(pkl_file_name, "wb") as f:
+ pickle.dump(traced_model, f)
+ with open(pkl_file_name, "rb") as f:
+ loaded = pickle.load(f)
+ except Exception as e:
+ self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+ loaded_output = loaded(**filtered_inputs)
+ loaded_output = flatten_output(loaded_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], loaded_output[i]),
+ f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
@slow
def test_batch_generation(self):
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
diff --git a/tests/models/xglm/test_tokenization_xglm.py b/tests/models/xglm/test_tokenization_xglm.py
index dd5c9f5e6a0c..05259ffaf9a3 100644
--- a/tests/models/xglm/test_tokenization_xglm.py
+++ b/tests/models/xglm/test_tokenization_xglm.py
@@ -179,7 +179,10 @@ def test_tokenization_base_easy_symbols(self):
@slow
def test_tokenization_base_hard_symbols(self):
- symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to unk, such as saoneuhaoesuth'
+ symbols = (
+ 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
+ " add words that should not exsist and be tokenized to unk, such as saoneuhaoesuth"
+ )
# fmt: off
original_tokenizer_encodings = [2, 1018, 67, 11, 1988, 2617, 5631, 278, 11, 3407, 48, 71630, 28085, 4, 3234, 157, 13, 6, 5, 6, 4, 3526, 768, 15, 659, 57, 298, 3983, 864, 129, 21, 6, 5, 13675, 377, 652, 7580, 10341, 155, 2817, 422, 1666, 7, 1674, 53, 113, 202277, 17892, 33, 60, 87, 4, 3234, 157, 61, 2667, 52376, 19, 88, 23, 735]
# fmt: on
diff --git a/tests/models/xlm_prophetnet/test_modeling_xlm_prophetnet.py b/tests/models/xlm_prophetnet/test_modeling_xlm_prophetnet.py
index 51e8502b9bd5..5dec186bc7b9 100644
--- a/tests/models/xlm_prophetnet/test_modeling_xlm_prophetnet.py
+++ b/tests/models/xlm_prophetnet/test_modeling_xlm_prophetnet.py
@@ -102,8 +102,18 @@ def test_xprophetnet_ntg_inference(self):
tokenizer = XLMProphetNetTokenizer.from_pretrained("microsoft/xprophetnet-large-wiki100-cased-xglue-ntg")
- EN_SENTENCE = "Microsoft Corporation intends to officially end free support for the Windows 7 operating system after January 14, 2020, according to the official portal of the organization. From that day, users of this system will not be able to receive security updates, which could make their computers vulnerable to cyber attacks."
- RU_SENTENCE = "орпорация Microsoft намерена официально прекратить бесплатную поддержку операционной системы Windows 7 после 14 января 2020 года, сообщается на официальном портале организации . С указанного дня пользователи этой системы не смогут получать обновления безопасности, из-за чего их компьютеры могут стать уязвимыми к кибератакам."
+ EN_SENTENCE = (
+ "Microsoft Corporation intends to officially end free support for the Windows 7 operating system after"
+ " January 14, 2020, according to the official portal of the organization. From that day, users of this"
+ " system will not be able to receive security updates, which could make their computers vulnerable to"
+ " cyber attacks."
+ )
+ RU_SENTENCE = (
+ "Š¾ŃŠæŠ¾ŃŠ°ŃŠøŃ Microsoft Š½Š°Š¼ŠµŃŠµŠ½Š° Š¾ŃŠøŃŠøŠ°Š»ŃŠ½Š¾ ŠæŃŠµŠŗŃŠ°ŃŠøŃŃ Š±ŠµŃŠæŠ»Š°ŃŠ½ŃŃ ŠæŠ¾Š“Š“ŠµŃŠ¶ŠŗŃ Š¾ŠæŠµŃŠ°Ńионной ŃŠøŃŃŠµŠ¼Ń Windows 7"
+ " ŠæŠ¾ŃŠ»Šµ 14 ŃŠ½Š²Š°ŃŃ 2020 гоГа, ŃŠ¾Š¾Š±ŃаеŃŃŃ Š½Š° Š¾ŃŠøŃŠøŠ°Š»ŃŠ½Š¾Š¼ поŃŃŠ°Š»Šµ Š¾ŃŠ³Š°Š½ŠøŠ·Š°ŃŠøŠø . Š” ŃŠŗŠ°Š·Š°Š½Š½Š¾Š³Š¾ Š“Š½Ń ŠæŠ¾Š»ŃŠ·Š¾Š²Š°Ńели"
+ " ŃŃŠ¾Š¹ ŃŠøŃŃŠµŠ¼Ń не ŃŠ¼Š¾Š³ŃŃ ŠæŠ¾Š»ŃŃŠ°ŃŃ Š¾Š±Š½Š¾Š²Š»ŠµŠ½ŠøŃ Š±ŠµŠ·Š¾ŠæŠ°ŃŠ½Š¾ŃŃŠø, ŠøŠ·-за ŃŠµŠ³Š¾ ŠøŃ
компŃŃŃŠµŃŃ Š¼Š¾Š³ŃŃ ŃŃŠ°ŃŃ ŃŃŠ·Š²ŠøŠ¼Ńми"
+ " Šŗ ŠŗŠøŠ±ŠµŃŠ°Ńакам."
+ )
ZH_SENTENCE = (
"ę ¹ę®čÆ„ē»ē»ēå®ę¹éØę·ē½ē«ļ¼å¾®č½Æå
¬åøęē®åØ2020幓1ę14ę„ä¹åę£å¼ē»ę¢åƹWindows 7ęä½ē³»ē»ēå
蓹ęÆęćä»é£ę¶čµ·ļ¼čÆ„ē³»ē»ēēØę·å°ę ę³ę„ę¶å®å
Øę“ę°ļ¼čæåÆč½ä¼ä½æä»ä»¬ēč®”ē®ęŗå®¹ęåå°ē½ē»ę»å»ć"
)
@@ -132,8 +142,9 @@ def test_xprophetnet_ntg_inference(self):
tokenizer.convert_ids_to_tokens(g, skip_special_tokens=True) for g in summary_ids_beam1
]
EXPECTED_TITLE_EN_BEAM1_TOK = "▁Microsoft ▁to ▁end ▁free ▁support ▁for ▁Windows ▁7".split(" ")
- EXPECTED_TITLE_RU_BEAM1_TOK = "▁Microsoft ▁намерен а ▁прекрати ть ▁бес плат ную ▁поддержку ▁Windows ▁7 ▁после ▁14 ▁января ▁2020 ▁года".split(
- " "
+ EXPECTED_TITLE_RU_BEAM1_TOK = (
+ "▁Microsoft ▁намерен а ▁прекрати ть ▁бес плат ную ▁поддержку ▁Windows ▁7 ▁после ▁14 ▁января ▁2020 ▁года"
+ .split(" ")
)
EXPECTED_TITLE_ZH_BEAM1_TOK = "微软 公司 打算 终止 对 Windows ▁7 操作 系统的 免费 支持".split(" ")
self.assertListEqual(
diff --git a/tests/models/xlm_roberta/test_tokenization_xlm_roberta.py b/tests/models/xlm_roberta/test_tokenization_xlm_roberta.py
index 53c5987fb2fb..c8f934b258b9 100644
--- a/tests/models/xlm_roberta/test_tokenization_xlm_roberta.py
+++ b/tests/models/xlm_roberta/test_tokenization_xlm_roberta.py
@@ -256,7 +256,10 @@ def test_tokenization_base_easy_symbols(self):
@slow
def test_tokenization_base_hard_symbols(self):
- symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to , such as saoneuhaoesuth'
+ symbols = (
+ 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
+ " add words that should not exsist and be tokenized to , such as saoneuhaoesuth"
+ )
original_tokenizer_encodings = [
0,
3293,
diff --git a/tests/models/xlnet/test_modeling_tf_xlnet.py b/tests/models/xlnet/test_modeling_tf_xlnet.py
index dc1ca077952c..bc8f31006bd4 100644
--- a/tests/models/xlnet/test_modeling_tf_xlnet.py
+++ b/tests/models/xlnet/test_modeling_tf_xlnet.py
@@ -403,7 +403,7 @@ def test_loss_computation(self):
added_label = prepared_for_class[
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
]
- loss_size = tf.size(added_label)
+ expected_loss_size = added_label.shape.as_list()[:1]
# `TFXLNetLMHeadModel` doesn't cut logits/labels
# if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
@@ -417,12 +417,12 @@ def test_loss_computation(self):
input_ids = prepared_for_class.pop(input_name)
loss = model(input_ids, **prepared_for_class)[0]
- self.assertEqual(loss.shape, [loss_size])
+ self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
# Test that model correctly compute the loss with a dict
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
loss = model(prepared_for_class)[0]
- self.assertEqual(loss.shape, [loss_size])
+ self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
# Test that model correctly compute the loss with a tuple
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -453,7 +453,7 @@ def test_loss_computation(self):
# Send to model
loss = model(tuple_input[:-1])[0]
- self.assertEqual(loss.shape, [loss_size])
+ self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
@require_tf
diff --git a/tests/models/xlnet/test_modeling_xlnet.py b/tests/models/xlnet/test_modeling_xlnet.py
index 2c26315ceb96..dca727b29942 100644
--- a/tests/models/xlnet/test_modeling_xlnet.py
+++ b/tests/models/xlnet/test_modeling_xlnet.py
@@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
all_generative_model_classes = (
(XLNetLMHeadModel,) if is_torch_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
+ fx_compatible = False
test_pruning = False
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here
diff --git a/tests/models/yolos/test_modeling_yolos.py b/tests/models/yolos/test_modeling_yolos.py
index 75d399eaa797..1d07e50ce7b2 100644
--- a/tests/models/yolos/test_modeling_yolos.py
+++ b/tests/models/yolos/test_modeling_yolos.py
@@ -31,7 +31,7 @@
from torch import nn
from transformers import YolosForObjectDetection, YolosModel
- from transformers.models.yolos.modeling_yolos import YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
+ from transformers.models.yolos.modeling_yolos import YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
@@ -86,9 +86,7 @@ def __init__(
self.num_detection_tokens = num_detection_tokens
# we set the expected sequence length (which is used in several tests)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens
- image_size = to_2tuple(self.image_size)
- patch_size = to_2tuple(self.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ num_patches = (image_size[1] // patch_size) * (image_size[0] // patch_size)
self.expected_seq_len = num_patches + 1 + self.num_detection_tokens
def prepare_config_and_inputs(self):
diff --git a/tests/models/yoso/test_modeling_yoso.py b/tests/models/yoso/test_modeling_yoso.py
index d71b051d0a22..0a0749dd7d9b 100644
--- a/tests/models/yoso/test_modeling_yoso.py
+++ b/tests/models/yoso/test_modeling_yoso.py
@@ -126,6 +126,11 @@ def get_config(self):
initializer_range=self.initializer_range,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def prepare_config_and_inputs_for_decoder(self):
(
config,
diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py
index 4ecfc917d56b..98ab0fad131e 100644
--- a/tests/onnx/test_onnx_v2.py
+++ b/tests/onnx/test_onnx_v2.py
@@ -1,3 +1,4 @@
+import os
from pathlib import Path
from tempfile import NamedTemporaryFile
from unittest import TestCase
@@ -6,7 +7,7 @@
import pytest
from parameterized import parameterized
-from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
+from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available
from transformers.onnx import (
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
OnnxConfig,
@@ -15,13 +16,22 @@
export,
validate_model_outputs,
)
-from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
+from transformers.onnx.utils import (
+ compute_effective_axis_dimension,
+ compute_serialized_parameters_size,
+ get_preprocessor,
+)
from transformers.testing_utils import require_onnx, require_rjieba, require_tf, require_torch, require_vision, slow
if is_torch_available() or is_tf_available():
from transformers.onnx.features import FeaturesManager
+if is_torch_available():
+ import torch
+
+ from transformers.models.deberta import modeling_deberta
+
@require_onnx
class OnnxUtilsTestCaseV2(TestCase):
@@ -176,19 +186,35 @@ def test_values_override(self):
("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"),
("convbert", "YituTech/conv-bert-base"),
+ ("codegen", "Salesforce/codegen-350M-multi"),
+ ("deberta", "microsoft/deberta-base"),
+ ("deberta-v2", "microsoft/deberta-v2-xlarge"),
+ ("convnext", "facebook/convnext-tiny-224"),
+ ("detr", "facebook/detr-resnet-50"),
("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"),
+ ("resnet", "microsoft/resnet-50"),
("roberta", "roberta-base"),
("roformer", "junnyu/roformer_chinese_base"),
+ ("squeezebert", "squeezebert/squeezebert-uncased"),
+ ("mobilebert", "google/mobilebert-uncased"),
+ ("xlm", "xlm-clm-ende-1024"),
("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"),
+ ("layoutlmv3", "microsoft/layoutlmv3-base"),
+ ("levit", "facebook/levit-128S"),
("vit", "google/vit-base-patch16-224"),
("deit", "facebook/deit-small-patch16-224"),
("beit", "microsoft/beit-base-patch16-224"),
("data2vec-text", "facebook/data2vec-text-base"),
+ ("data2vec-vision", "facebook/data2vec-vision-base"),
+ ("perceiver", "deepmind/language-perceiver", ("masked-lm", "sequence-classification")),
+ ("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)),
+ ("yolos", "hustvl/yolos-tiny"),
}
PYTORCH_EXPORT_WITH_PAST_MODELS = {
+ ("bloom", "bigscience/bloom-350m"),
("gpt2", "gpt2"),
("gpt-neo", "EleutherAI/gpt-neo-125M"),
}
@@ -198,9 +224,15 @@ def test_values_override(self):
("mbart", "sshleifer/tiny-mbart"),
("t5", "t5-small"),
("marian", "Helsinki-NLP/opus-mt-en-de"),
+ ("mt5", "google/mt5-base"),
("m2m-100", "facebook/m2m100_418M"),
("blenderbot-small", "facebook/blenderbot_small-90M"),
("blenderbot", "facebook/blenderbot-400M-distill"),
+ ("bigbird-pegasus", "google/bigbird-pegasus-large-arxiv"),
+ ("longt5", "google/long-t5-local-base"),
+ # Disable for now as it causes a fatal error `Floating point exception (core dumped)` and the subsequent tests are
+ # not run.
+ # ("longt5", "google/long-t5-tglobal-base"),
}
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
@@ -222,10 +254,15 @@ def test_values_override(self):
def _get_models_to_test(export_models_list):
models_to_test = []
if is_torch_available() or is_tf_available():
- for (name, model) in export_models_list:
- for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
- name
- ).items():
+ for name, model, *features in export_models_list:
+ if features:
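+ # `features` is a one-element list holding the tuple of requested feature names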
+ feature_config_mapping = {
+ feature: FeaturesManager.get_config(name, feature) for _ in features for feature in _
+ }
+ else:
+ feature_config_mapping = FeaturesManager.get_supported_features_for_model_type(name)
+
+ for feature, onnx_config_class_constructor in feature_config_mapping.items():
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
return sorted(models_to_test)
else:
@@ -240,7 +277,7 @@ class OnnxExportTestCaseV2(TestCase):
Integration tests ensuring supported models are correctly exported
"""
- def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
+ def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"):
from transformers.onnx import export
model_class = FeaturesManager.get_model_class_for_feature(feature)
@@ -253,24 +290,20 @@ def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_c
if torch_version < onnx_config.torch_onnx_minimum_version:
pytest.skip(
- f"Skipping due to incompatible PyTorch version. Minimum required is {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
+ "Skipping due to incompatible PyTorch version. Minimum required is"
+ f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
)
- # Check the modality of the inputs and instantiate the appropriate preprocessor
- if model.main_input_name == "input_ids":
- preprocessor = AutoTokenizer.from_pretrained(model_name)
- # Useful for causal lm models that do not use pad tokens.
- if not getattr(config, "pad_token_id", None):
- config.pad_token_id = preprocessor.eos_token_id
- elif model.main_input_name == "pixel_values":
- preprocessor = AutoFeatureExtractor.from_pretrained(model_name)
- else:
- raise ValueError(f"Unsupported model input name: {model.main_input_name}")
+ preprocessor = get_preprocessor(model_name)
+
+ # Useful for causal lm models that do not use pad tokens.
+ if isinstance(preprocessor, PreTrainedTokenizerBase) and not getattr(config, "pad_token_id", None):
+ config.pad_token_id = preprocessor.eos_token_id
with NamedTemporaryFile("w") as output:
try:
onnx_inputs, onnx_outputs = export(
- preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
+ preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name), device=device
)
validate_model_outputs(
onnx_config,
@@ -291,6 +324,14 @@ def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_c
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
+ @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
+ @slow
+ @require_torch
+ @require_vision
+ @require_rjieba
+ def test_pytorch_export_on_cuda(self, test_name, name, model_name, feature, onnx_config_class_constructor):
+ self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda")
+
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
@slow
@require_torch
@@ -325,3 +366,40 @@ def test_tensorflow_export_seq2seq_with_past(
self, test_name, name, model_name, feature, onnx_config_class_constructor
):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
+
+
+class StableDropoutTestCase(TestCase):
+ """Tests export of StableDropout module."""
+
+ @require_torch
+ @pytest.mark.filterwarnings("ignore:.*Dropout.*:UserWarning:torch.onnx.*") # torch.onnx is spammy.
+ def test_training(self):
+ """Tests export of StableDropout in training mode."""
+ devnull = open(os.devnull, "wb")
+ # drop_prob must be > 0 for the test to be meaningful
+ sd = modeling_deberta.StableDropout(0.1)
+ # Avoid warnings in training mode
+ do_constant_folding = False
+ # Dropout is a no-op in inference mode
+ training = torch.onnx.TrainingMode.PRESERVE
+ input = (torch.randn(2, 2),)
+
+ torch.onnx.export(
+ sd,
+ input,
+ devnull,
+ opset_version=12, # Minimum supported
+ do_constant_folding=do_constant_folding,
+ training=training,
+ )
+
+ # Expected to fail with opset_version < 12
+ with self.assertRaises(Exception):
+ torch.onnx.export(
+ sd,
+ input,
+ devnull,
+ opset_version=11,
+ do_constant_folding=do_constant_folding,
+ training=training,
+ )
diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index ec54055d7d62..0523639cc4fe 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -141,15 +141,8 @@ def test_small_model_pt(self):
@require_torch
def test_small_model_pt_seq2seq(self):
- model_id = "hf-internal-testing/tiny-random-speech-encoder-decoder"
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
-
speech_recognizer = pipeline(
- task="automatic-speech-recognition",
- model=model_id,
- tokenizer=tokenizer,
- feature_extractor=feature_extractor,
+ model="hf-internal-testing/tiny-random-speech-encoder-decoder",
framework="pt",
)
@@ -184,7 +177,9 @@ def test_large_model_pt_with_lm(self):
self.assertEqual(
output,
{
- "text": "y en las ramas medio sumergidas revoloteaban algunos pƔjaros de quimƩrico y legendario plumajre"
+ "text": (
+ "y en las ramas medio sumergidas revoloteaban algunos pƔjaros de quimƩrico y legendario plumajre"
+ )
},
)
@@ -194,7 +189,9 @@ def test_large_model_pt_with_lm(self):
self.assertEqual(
output,
{
- "text": "y en las ramas medio sumergidas revoloteaban algunos pƔjaros de quimƩrico y legendario plumajcri",
+ "text": (
+ "y en las ramas medio sumergidas revoloteaban algunos pƔjaros de quimƩrico y legendario plumajcri"
+ ),
"chunks": [
{"text": "y", "timestamp": (0.52, 0.54)},
{"text": "en", "timestamp": (0.6, 0.68)},
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 818191b72518..5d5c8fa2333e 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -15,27 +15,56 @@
import copy
import importlib
import logging
+import os
import random
import string
+import sys
+import tempfile
import unittest
from abc import abstractmethod
from functools import lru_cache
+from pathlib import Path
from unittest import skipIf
+import numpy as np
+
+from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
+from requests.exceptions import HTTPError
from transformers import (
FEATURE_EXTRACTOR_MAPPING,
TOKENIZER_MAPPING,
AutoFeatureExtractor,
+ AutoModelForSequenceClassification,
AutoTokenizer,
DistilBertForSequenceClassification,
IBertConfig,
RobertaConfig,
TextClassificationPipeline,
+ TFAutoModelForSequenceClassification,
pipeline,
)
-from transformers.pipelines import get_task
-from transformers.pipelines.base import _pad
-from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch
+from transformers.pipelines import PIPELINE_REGISTRY, get_task
+from transformers.pipelines.base import Pipeline, _pad
+from transformers.testing_utils import (
+ TOKEN,
+ USER,
+ CaptureLogger,
+ is_pipeline_test,
+ is_staging_test,
+ nested_simplify,
+ require_scatter,
+ require_tensorflow_probability,
+ require_tf,
+ require_torch,
+ slow,
+)
+from transformers.utils import is_tf_available, is_torch_available
+from transformers.utils import logging as transformers_logging
+
+
+sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
+
+from test_module.custom_pipeline import PairClassificationPipeline # noqa E402
logger = logging.getLogger(__name__)
@@ -184,7 +213,8 @@ def test(self):
if tokenizer is None and feature_extractor is None:
self.skipTest(
- f"Ignoring {ModelClass}, cannot create a tokenizer or feature_extractor (PerceiverConfig with no FastTokenizer ?)"
+ f"Ignoring {ModelClass}, cannot create a tokenizer or feature_extractor (PerceiverConfig with"
+ " no FastTokenizer ?)"
)
pipeline, examples = self.get_test_pipeline(model, tokenizer, feature_extractor)
if pipeline is None:
@@ -460,8 +490,8 @@ def test_pipeline_offset_mapping(self):
@is_pipeline_test
-@require_torch
class PipelineUtilsTest(unittest.TestCase):
+ @require_torch
def test_pipeline_dataset(self):
from transformers.pipelines.pt_utils import PipelineDataset
@@ -475,6 +505,7 @@ def add(number, extra=0):
outputs = [dataset[i] for i in range(4)]
self.assertEqual(outputs, [2, 3, 4, 5])
+ @require_torch
def test_pipeline_iterator(self):
from transformers.pipelines.pt_utils import PipelineIterator
@@ -489,6 +520,7 @@ def add(number, extra=0):
outputs = [item for item in dataset]
self.assertEqual(outputs, [2, 3, 4, 5])
+ @require_torch
def test_pipeline_iterator_no_len(self):
from transformers.pipelines.pt_utils import PipelineIterator
@@ -506,6 +538,7 @@ def add(number, extra=0):
outputs = [item for item in dataset]
self.assertEqual(outputs, [2, 3, 4, 5])
+ @require_torch
def test_pipeline_batch_unbatch_iterator(self):
from transformers.pipelines.pt_utils import PipelineIterator
@@ -519,6 +552,7 @@ def add(number, extra=0):
outputs = [item for item in dataset]
self.assertEqual(outputs, [{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}])
+ @require_torch
def test_pipeline_batch_unbatch_iterator_tensors(self):
import torch
@@ -536,6 +570,7 @@ def add(number, extra=0):
nested_simplify(outputs), [{"id": [[12, 22]]}, {"id": [[2, 3]]}, {"id": [[2, 4]]}, {"id": [[5]]}]
)
+ @require_torch
def test_pipeline_chunk_iterator(self):
from transformers.pipelines.pt_utils import PipelineChunkIterator
@@ -551,6 +586,7 @@ def preprocess_chunk(n: int):
self.assertEqual(outputs, [0, 1, 0, 1, 2])
+ @require_torch
def test_pipeline_pack_iterator(self):
from transformers.pipelines.pt_utils import PipelinePackIterator
@@ -583,6 +619,7 @@ def pack(item):
],
)
+ @require_torch
def test_pipeline_pack_unbatch_iterator(self):
from transformers.pipelines.pt_utils import PipelinePackIterator
@@ -606,3 +643,321 @@ def add(number, extra=0):
outputs = [item for item in dataset]
self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]])
+
+ @slow
+ @require_torch
+ def test_load_default_pipelines_pt(self):
+ import torch
+
+ from transformers.pipelines import SUPPORTED_TASKS
+
+ set_seed_fn = lambda: torch.manual_seed(0) # noqa: E731
+ for task in SUPPORTED_TASKS.keys():
+ if task == "table-question-answering":
+ # test table-question-answering in a separate test due to extra dependencies
+ continue
+
+ self.check_default_pipeline(task, "pt", set_seed_fn, self.check_models_equal_pt)
+
+ @slow
+ @require_tf
+ def test_load_default_pipelines_tf(self):
+ import tensorflow as tf
+
+ from transformers.pipelines import SUPPORTED_TASKS
+
+ set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
+ for task in SUPPORTED_TASKS.keys():
+ if task == "table-question-answering":
+ # test table-question-answering in a separate test due to extra dependencies
+ continue
+
+ self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
+
+ @slow
+ @require_torch
+ @require_scatter
+ def test_load_default_pipelines_pt_table_qa(self):
+ import torch
+
+ set_seed_fn = lambda: torch.manual_seed(0) # noqa: E731
+ self.check_default_pipeline("table-question-answering", "pt", set_seed_fn, self.check_models_equal_pt)
+
+ @slow
+ @require_tf
+ @require_tensorflow_probability
+ def test_load_default_pipelines_tf_table_qa(self):
+ import tensorflow as tf
+
+ set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
+ self.check_default_pipeline("table-question-answering", "tf", set_seed_fn, self.check_models_equal_tf)
+
+ def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equal_fn):
+ from transformers.pipelines import SUPPORTED_TASKS, pipeline
+
+ task_dict = SUPPORTED_TASKS[task]
+ # test to compare pipeline to manually loading the respective model
+ model = None
+ relevant_auto_classes = task_dict[framework]
+
+ if len(relevant_auto_classes) == 0:
+ # task has no default
+ logger.debug(f"{task} in {framework} has no default")
+ return
+
+ # by default use first class
+ auto_model_cls = relevant_auto_classes[0]
+
+ # retrieve correct model ids
+ if task == "translation":
+ # special case for translation pipeline which has multiple languages
+ model_ids = []
+ revisions = []
+ tasks = []
+ for translation_pair in task_dict["default"].keys():
+ model_id, revision = task_dict["default"][translation_pair]["model"][framework]
+
+ model_ids.append(model_id)
+ revisions.append(revision)
+ tasks.append(task + f"_{'_to_'.join(translation_pair)}")
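+ # e.g. the ("en", "de") pair becomes the task name "translation_en_to_de"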
+ else:
+ # normal case - non-translation pipeline
+ model_id, revision = task_dict["default"]["model"][framework]
+
+ model_ids = [model_id]
+ revisions = [revision]
+ tasks = [task]
+
+ # check for equality
+ for model_id, revision, task in zip(model_ids, revisions, tasks):
+ # load default model
+ try:
+ set_seed_fn()
+ model = auto_model_cls.from_pretrained(model_id, revision=revision)
+ except ValueError:
+ # the first auto class may not be compatible with the model, go to the next auto class
+ auto_model_cls = relevant_auto_classes[1]
+ set_seed_fn()
+ model = auto_model_cls.from_pretrained(model_id, revision=revision)
+
+ # load default pipeline
+ set_seed_fn()
+ default_pipeline = pipeline(task, framework=framework)
+
+ # compare pipeline model with default model
+ models_are_equal = check_models_equal_fn(default_pipeline.model, model)
+ self.assertTrue(models_are_equal, f"{task} model doesn't match pipeline.")
+
+ logger.debug(f"{task} in {framework} succeeded with {model_id}.")
+
+ def check_models_equal_pt(self, model1, model2):
+ models_are_equal = True
+ for model1_p, model2_p in zip(model1.parameters(), model2.parameters()):
+ if model1_p.data.ne(model2_p.data).sum() > 0:
+ models_are_equal = False
+
+ return models_are_equal
+
+ def check_models_equal_tf(self, model1, model2):
+ models_are_equal = True
+ for model1_p, model2_p in zip(model1.weights, model2.weights):
+ if np.abs(model1_p.numpy() - model2_p.numpy()).sum() > 1e-5:
+ models_are_equal = False
+
+ return models_are_equal
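+ # Side note (not part of the original tests): an equivalent, terser PT check could
+ # rely on torch.equal, e.g.
+ #     all(torch.equal(p1, p2) for p1, p2 in zip(model1.parameters(), model2.parameters()))
+ # The explicit loops above mirror the TF variant, which compares numpy weight arrays instead.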
+
+
+class CustomPipeline(Pipeline):
+ def _sanitize_parameters(self, **kwargs):
+ preprocess_kwargs = {}
+ if "maybe_arg" in kwargs:
+ preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
+ return preprocess_kwargs, {}, {}
+
+ def preprocess(self, text, maybe_arg=2):
+ input_ids = self.tokenizer(text, return_tensors="pt")
+ return input_ids
+
+ def _forward(self, model_inputs):
+ outputs = self.model(**model_inputs)
+ return outputs
+
+ def postprocess(self, model_outputs):
+ return model_outputs["logits"].softmax(-1).numpy()
+
+
+@is_pipeline_test
+class CustomPipelineTest(unittest.TestCase):
+ def test_warning_logs(self):
+ transformers_logging.set_verbosity_debug()
+ logger_ = transformers_logging.get_logger("transformers.pipelines.base")
+
+ alias = "text-classification"
+ # Get the original task, so we can restore it at the end.
+ # (otherwise the subsequent tests in `TextClassificationPipelineTests` will fail)
+ _, original_task, _ = PIPELINE_REGISTRY.check_task(alias)
+
+ try:
+ with CaptureLogger(logger_) as cm:
+ PIPELINE_REGISTRY.register_pipeline(alias, PairClassificationPipeline)
+ self.assertIn(f"{alias} is already registered", cm.out)
+ finally:
+ # restore
+ PIPELINE_REGISTRY.supported_tasks[alias] = original_task
+
+ def test_register_pipeline(self):
+ PIPELINE_REGISTRY.register_pipeline(
+ "custom-text-classification",
+ pipeline_class=PairClassificationPipeline,
+ pt_model=AutoModelForSequenceClassification if is_torch_available() else None,
+ tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None,
+ default={"pt": "hf-internal-testing/tiny-random-distilbert"},
+ type="text",
+ )
+ assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks()
+
+ _, task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification")
+ self.assertEqual(task_def["pt"], (AutoModelForSequenceClassification,) if is_torch_available() else ())
+ self.assertEqual(task_def["tf"], (TFAutoModelForSequenceClassification,) if is_tf_available() else ())
+ self.assertEqual(task_def["type"], "text")
+ self.assertEqual(task_def["impl"], PairClassificationPipeline)
+ self.assertEqual(task_def["default"], {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}})
+
+ # Clean registry for next tests.
+ del PIPELINE_REGISTRY.supported_tasks["custom-text-classification"]
+
+ def test_dynamic_pipeline(self):
+ PIPELINE_REGISTRY.register_pipeline(
+ "pair-classification",
+ pipeline_class=PairClassificationPipeline,
+ pt_model=AutoModelForSequenceClassification if is_torch_available() else None,
+ tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None,
+ )
+
+ classifier = pipeline("pair-classification", model="hf-internal-testing/tiny-random-bert")
+
+ # Clean registry as we won't need the pipeline to be in it for the rest to work.
+ del PIPELINE_REGISTRY.supported_tasks["pair-classification"]
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ classifier.save_pretrained(tmp_dir)
+ # checks
+ self.assertDictEqual(
+ classifier.model.config.custom_pipelines,
+ {
+ "pair-classification": {
+ "impl": "custom_pipeline.PairClassificationPipeline",
+ "pt": ("AutoModelForSequenceClassification",) if is_torch_available() else (),
+ "tf": ("TFAutoModelForSequenceClassification",) if is_tf_available() else (),
+ }
+ },
+ )
+ # Fails if the user forgets to pass along `trust_remote_code=True`
+ with self.assertRaises(ValueError):
+ _ = pipeline(model=tmp_dir)
+
+ new_classifier = pipeline(model=tmp_dir, trust_remote_code=True)
+ # Using trust_remote_code=False forces the traditional pipeline tag
+ old_classifier = pipeline("text-classification", model=tmp_dir, trust_remote_code=False)
+ # Can't make an isinstance check because the new_classifier is from the PairClassificationPipeline class of a
+ # dynamic module
+ self.assertEqual(new_classifier.__class__.__name__, "PairClassificationPipeline")
+ self.assertEqual(new_classifier.task, "pair-classification")
+ results = new_classifier("I hate you", second_text="I love you")
+ self.assertDictEqual(
+ nested_simplify(results),
+ {"label": "LABEL_0", "score": 0.505, "logits": [-0.003, -0.024]},
+ )
+
+ self.assertEqual(old_classifier.__class__.__name__, "TextClassificationPipeline")
+ self.assertEqual(old_classifier.task, "text-classification")
+ results = old_classifier("I hate you", text_pair="I love you")
+ self.assertListEqual(
+ nested_simplify(results),
+ [{"label": "LABEL_0", "score": 0.505}],
+ )
+
+
+@require_torch
+@is_staging_test
+class DynamicPipelineTester(unittest.TestCase):
+ vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "I", "love", "hate", "you"]
+
+ @classmethod
+ def setUpClass(cls):
+ cls._token = TOKEN
+ set_access_token(TOKEN)
+ HfFolder.save_token(TOKEN)
+
+ @classmethod
+ def tearDownClass(cls):
+ try:
+ delete_repo(token=cls._token, repo_id="test-dynamic-pipeline")
+ except HTTPError:
+ pass
+
+ def test_push_to_hub_dynamic_pipeline(self):
+ from transformers import BertConfig, BertForSequenceClassification, BertTokenizer
+
+ PIPELINE_REGISTRY.register_pipeline(
+ "pair-classification",
+ pipeline_class=PairClassificationPipeline,
+ pt_model=AutoModelForSequenceClassification,
+ )
+
+ config = BertConfig(
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+ )
+ model = BertForSequenceClassification(config).eval()
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-pipeline", use_auth_token=self._token)
+
+ vocab_file = os.path.join(tmp_dir, "vocab.txt")
+ with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
+ vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
+ tokenizer = BertTokenizer(vocab_file)
+
+ classifier = pipeline("pair-classification", model=model, tokenizer=tokenizer)
+
+ # Clean registry as we won't need the pipeline to be in it for the rest to work.
+ del PIPELINE_REGISTRY.supported_tasks["pair-classification"]
+
+ classifier.save_pretrained(tmp_dir)
+ # checks
+ self.assertDictEqual(
+ classifier.model.config.custom_pipelines,
+ {
+ "pair-classification": {
+ "impl": "custom_pipeline.PairClassificationPipeline",
+ "pt": ("AutoModelForSequenceClassification",),
+ "tf": (),
+ }
+ },
+ )
+
+ repo.push_to_hub()
+
+ # Fails if the user forgets to pass along `trust_remote_code=True`
+ with self.assertRaises(ValueError):
+ _ = pipeline(model=f"{USER}/test-dynamic-pipeline")
+
+ new_classifier = pipeline(model=f"{USER}/test-dynamic-pipeline", trust_remote_code=True)
+ # Can't make an isinstance check because the new_classifier is from the PairClassificationPipeline class of a
+ # dynamic module
+ self.assertEqual(new_classifier.__class__.__name__, "PairClassificationPipeline")
+
+ results = classifier("I hate you", second_text="I love you")
+ new_results = new_classifier("I hate you", second_text="I love you")
+ self.assertDictEqual(nested_simplify(results), nested_simplify(new_results))
+
+ # Using trust_remote_code=False forces the traditional pipeline tag
+ old_classifier = pipeline(
+ "text-classification", model=f"{USER}/test-dynamic-pipeline", trust_remote_code=False
+ )
+ self.assertEqual(old_classifier.__class__.__name__, "TextClassificationPipeline")
+ self.assertEqual(old_classifier.task, "text-classification")
+ new_results = old_classifier("I hate you", text_pair="I love you")
+ self.assertListEqual(
+ nested_simplify([{"label": results["label"], "score": results["score"]}]), nested_simplify(new_results)
+ )
diff --git a/tests/pipelines/test_pipelines_fill_mask.py b/tests/pipelines/test_pipelines_fill_mask.py
index ed551bf6f490..d85ab8d7ce32 100644
--- a/tests/pipelines/test_pipelines_fill_mask.py
+++ b/tests/pipelines/test_pipelines_fill_mask.py
@@ -16,7 +16,14 @@
from transformers import MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, FillMaskPipeline, pipeline
from transformers.pipelines import PipelineException
-from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
+from transformers.testing_utils import (
+ is_pipeline_test,
+ nested_simplify,
+ require_tf,
+ require_torch,
+ require_torch_gpu,
+ slow,
+)
from .test_pipelines_common import ANY, PipelineTestCaseMeta
@@ -130,6 +137,19 @@ def test_small_model_pt(self):
],
)
+ @require_torch_gpu
+ def test_fp16_casting(self):
+ pipe = pipeline("fill-mask", model="hf-internal-testing/tiny-random-distilbert", device=0, framework="pt")
+
+ # convert model to fp16
+ pipe.model.half()
+
+ response = pipe("Paris is the [MASK] of France.")
+ # We don't actually care about the result; we just want to make sure
+ # it works, meaning the float16 tensor was cast back to float32
+ # for postprocessing.
+ self.assertIsInstance(response, list)
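+ # (Illustrative extra check, not in the original test: each fill-mask result is a
+ # dict with a float "score", so one could additionally assert
+ # isinstance(response[0]["score"], float) to confirm the values survived the cast.)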
+
@slow
@require_torch
def test_large_model_pt(self):
diff --git a/tests/pipelines/test_pipelines_image_segmentation.py b/tests/pipelines/test_pipelines_image_segmentation.py
index fe3ff1ee88f6..1884682ec535 100644
--- a/tests/pipelines/test_pipelines_image_segmentation.py
+++ b/tests/pipelines/test_pipelines_image_segmentation.py
@@ -148,7 +148,7 @@ def test_small_model_tf(self):
@require_torch
def test_small_model_pt(self):
- model_id = "mishig/tiny-detr-mobilenetsv3-panoptic"
+ model_id = "hf-internal-testing/tiny-detr-mobilenetsv3-panoptic"
model = AutoModelForImageSegmentation.from_pretrained(model_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
@@ -164,12 +164,12 @@ def test_small_model_pt(self):
[
{
"score": 0.004,
- "label": "LABEL_0",
+ "label": "LABEL_215",
"mask": "34eecd16bbfb0f476083ef947d81bf66",
},
{
"score": 0.004,
- "label": "LABEL_0",
+ "label": "LABEL_215",
"mask": "34eecd16bbfb0f476083ef947d81bf66",
},
],
@@ -192,24 +192,24 @@ def test_small_model_pt(self):
[
{
"score": 0.004,
- "label": "LABEL_0",
+ "label": "LABEL_215",
"mask": "34eecd16bbfb0f476083ef947d81bf66",
},
{
"score": 0.004,
- "label": "LABEL_0",
+ "label": "LABEL_215",
"mask": "34eecd16bbfb0f476083ef947d81bf66",
},
],
[
{
"score": 0.004,
- "label": "LABEL_0",
+ "label": "LABEL_215",
"mask": "34eecd16bbfb0f476083ef947d81bf66",
},
{
"score": 0.004,
- "label": "LABEL_0",
+ "label": "LABEL_215",
"mask": "34eecd16bbfb0f476083ef947d81bf66",
},
],
diff --git a/tests/pipelines/test_pipelines_object_detection.py b/tests/pipelines/test_pipelines_object_detection.py
index d0694d9bdffd..538f31315157 100644
--- a/tests/pipelines/test_pipelines_object_detection.py
+++ b/tests/pipelines/test_pipelines_object_detection.py
@@ -106,7 +106,7 @@ def test_small_model_tf(self):
@require_torch
def test_small_model_pt(self):
- model_id = "mishig/tiny-detr-mobilenetsv3"
+ model_id = "hf-internal-testing/tiny-detr-mobilenetsv3"
model = AutoModelForObjectDetection.from_pretrained(model_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
@@ -117,8 +117,8 @@ def test_small_model_pt(self):
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
- {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
- {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
+ {"score": 0.3376, "label": "LABEL_0", "box": {"xmin": 159, "ymin": 120, "xmax": 480, "ymax": 359}},
+ {"score": 0.3376, "label": "LABEL_0", "box": {"xmin": 159, "ymin": 120, "xmax": 480, "ymax": 359}},
],
)
@@ -134,12 +134,12 @@ def test_small_model_pt(self):
nested_simplify(outputs, decimals=4),
[
[
- {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
- {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
+ {"score": 0.3376, "label": "LABEL_0", "box": {"xmin": 159, "ymin": 120, "xmax": 480, "ymax": 359}},
+ {"score": 0.3376, "label": "LABEL_0", "box": {"xmin": 159, "ymin": 120, "xmax": 480, "ymax": 359}},
],
[
- {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
- {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
+ {"score": 0.3376, "label": "LABEL_0", "box": {"xmin": 159, "ymin": 120, "xmax": 480, "ymax": 359}},
+ {"score": 0.3376, "label": "LABEL_0", "box": {"xmin": 159, "ymin": 120, "xmax": 480, "ymax": 359}},
],
],
)
diff --git a/tests/pipelines/test_pipelines_question_answering.py b/tests/pipelines/test_pipelines_question_answering.py
index e37fa1277683..001254aa94b0 100644
--- a/tests/pipelines/test_pipelines_question_answering.py
+++ b/tests/pipelines/test_pipelines_question_answering.py
@@ -106,17 +106,94 @@ def run_pipeline_test(self, question_answerer, _):
)
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
+ # Using batch is OK
+ new_outputs = question_answerer(
+ question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." * 20, batch_size=2
+ )
+ self.assertEqual(new_outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
+ self.assertEqual(outputs, new_outputs)
+
@require_torch
def test_small_model_pt(self):
question_answerer = pipeline(
"question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad"
)
+
outputs = question_answerer(
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
)
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
+ @require_torch
+ def test_small_model_pt_iterator(self):
+ # https://github.com/huggingface/transformers/issues/18510
+ pipe = pipeline(model="sshleifer/tiny-distilbert-base-cased-distilled-squad", batch_size=16, framework="pt")
+
+ def data():
+ for i in range(10):
+ yield {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."}
+
+ for outputs in pipe(data()):
+ self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
+
+ @require_torch
+ def test_small_model_pt_softmax_trick(self):
+ question_answerer = pipeline(
+ "question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad"
+ )
+
+ real_postprocess = question_answerer.postprocess
+
+ # Tweak start and stop to make sure we encounter the softmax logits
+ # bug.
+ def ensure_large_logits_postprocess(
+ model_outputs,
+ top_k=1,
+ handle_impossible_answer=False,
+ max_answer_len=15,
+ ):
+ for output in model_outputs:
+ output["start"] = output["start"] * 1e6
+ output["end"] = output["end"] * 1e6
+ return real_postprocess(
+ model_outputs,
+ top_k=top_k,
+ handle_impossible_answer=handle_impossible_answer,
+ max_answer_len=max_answer_len,
+ )
+
+ question_answerer.postprocess = ensure_large_logits_postprocess
+
+ outputs = question_answerer(
+ question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
+ )
+
+ self.assertEqual(nested_simplify(outputs), {"score": 0.028, "start": 0, "end": 11, "answer": "HuggingFace"})
+
+ @slow
+ @require_torch
+ def test_small_model_japanese(self):
+ question_answerer = pipeline(
+ "question-answering",
+ model="KoichiYasuoka/deberta-base-japanese-aozora-ud-head",
+ )
+ output = question_answerer(question="国語", context="全学年にわたって小学校の国語の教科書に挿し絵が用いられている")
+
+ # Wrong answer, the whole text is identified as one "word" since the tokenizer does not include
+ # a pretokenizer
+ self.assertEqual(
+ nested_simplify(output),
+ {"score": 1.0, "start": 0, "end": 30, "answer": "å
Øå¦å¹“ć«ććć£ć¦å°å¦ę ”ć®å½čŖć®ęē§ęøć«ęæćēµµćēØćććć¦ćć"},
+ )
+
+ # Disable word alignment
+ output = question_answerer(question="国語", context="全学年にわたって小学校の国語の教科書に挿し絵が用いられている", align_to_words=False)
+ self.assertEqual(
+ nested_simplify(output),
+ {"score": 1.0, "start": 15, "end": 18, "answer": "ęē§ęø"},
+ )
+
@slow
@require_torch
def test_small_model_long_context_cls_slow(self):
@@ -164,7 +241,42 @@ def test_large_model_issue(self):
)
outputs = qa_pipeline(
{
- "context": "Yes Bank founder Rana Kapoor has approached the Bombay High Court, challenging a special court's order from August this year that had remanded him in police custody for a week in a multi-crore loan fraud case. Kapoor, who is currently lodged in Taloja Jail, is an accused in the loan fraud case and some related matters being probed by the CBI and Enforcement Directorate. A single bench presided over by Justice S K Shinde on Tuesday posted the plea for further hearing on October 14. In his plea filed through advocate Vijay Agarwal, Kapoor claimed that the special court's order permitting the CBI's request for police custody on August 14 was illegal and in breach of the due process of law. Therefore, his police custody and subsequent judicial custody in the case were all illegal. Kapoor has urged the High Court to quash and set aside the special court's order dated August 14. As per his plea, in August this year, the CBI had moved two applications before the special court, one seeking permission to arrest Kapoor, who was already in judicial custody at the time in another case, and the other, seeking his police custody. While the special court refused to grant permission to the CBI to arrest Kapoor, it granted the central agency's plea for his custody. Kapoor, however, said in his plea that before filing an application for his arrest, the CBI had not followed the process of issuing him a notice under Section 41 of the CrPC for appearance before it. He further said that the CBI had not taken prior sanction as mandated under section 17 A of the Prevention of Corruption Act for prosecuting him. The special court, however, had said in its order at the time that as Kapoor was already in judicial custody in another case and was not a free man the procedure mandated under Section 41 of the CrPC need not have been adhered to as far as issuing a prior notice of appearance was concerned. ADVERTISING It had also said that case records showed that the investigating officer had taken an approval from a managing director of Yes Bank before beginning the proceedings against Kapoor and such a permission was a valid sanction. However, Kapoor in his plea said that the above order was bad in law and sought that it be quashed and set aside. The law mandated that if initial action was not in consonance with legal procedures, then all subsequent actions must be held as illegal, he said, urging the High Court to declare the CBI remand and custody and all subsequent proceedings including the further custody as illegal and void ab-initio. In a separate plea before the High Court, Kapoor's daughter Rakhee Kapoor-Tandon has sought exemption from in-person appearance before a special PMLA court. Rakhee has stated that she is a resident of the United Kingdom and is unable to travel to India owing to restrictions imposed due to the COVID-19 pandemic. According to the CBI, in the present case, Kapoor had obtained a gratification or pecuniary advantage of ā¹ 307 crore, and thereby caused Yes Bank a loss of ā¹ 1,800 crore by extending credit facilities to Avantha Group, when it was not eligible for the same",
+ "context": (
+ "Yes Bank founder Rana Kapoor has approached the Bombay High Court, challenging a special court's"
+ " order from August this year that had remanded him in police custody for a week in a multi-crore"
+ " loan fraud case. Kapoor, who is currently lodged in Taloja Jail, is an accused in the loan fraud"
+ " case and some related matters being probed by the CBI and Enforcement Directorate. A single"
+ " bench presided over by Justice S K Shinde on Tuesday posted the plea for further hearing on"
+ " October 14. In his plea filed through advocate Vijay Agarwal, Kapoor claimed that the special"
+ " court's order permitting the CBI's request for police custody on August 14 was illegal and in"
+ " breach of the due process of law. Therefore, his police custody and subsequent judicial custody"
+ " in the case were all illegal. Kapoor has urged the High Court to quash and set aside the special"
+ " court's order dated August 14. As per his plea, in August this year, the CBI had moved two"
+ " applications before the special court, one seeking permission to arrest Kapoor, who was already"
+ " in judicial custody at the time in another case, and the other, seeking his police custody."
+ " While the special court refused to grant permission to the CBI to arrest Kapoor, it granted the"
+ " central agency's plea for his custody. Kapoor, however, said in his plea that before filing an"
+ " application for his arrest, the CBI had not followed the process of issuing him a notice under"
+ " Section 41 of the CrPC for appearance before it. He further said that the CBI had not taken"
+ " prior sanction as mandated under section 17 A of the Prevention of Corruption Act for"
+ " prosecuting him. The special court, however, had said in its order at the time that as Kapoor"
+ " was already in judicial custody in another case and was not a free man the procedure mandated"
+ " under Section 41 of the CrPC need not have been adhered to as far as issuing a prior notice of"
+ " appearance was concerned. ADVERTISING It had also said that case records showed that the"
+ " investigating officer had taken an approval from a managing director of Yes Bank before"
+ " beginning the proceedings against Kapoor and such a permission was a valid sanction. However,"
+ " Kapoor in his plea said that the above order was bad in law and sought that it be quashed and"
+ " set aside. The law mandated that if initial action was not in consonance with legal procedures,"
+ " then all subsequent actions must be held as illegal, he said, urging the High Court to declare"
+ " the CBI remand and custody and all subsequent proceedings including the further custody as"
+ " illegal and void ab-initio. In a separate plea before the High Court, Kapoor's daughter Rakhee"
+ " Kapoor-Tandon has sought exemption from in-person appearance before a special PMLA court. Rakhee"
+ " has stated that she is a resident of the United Kingdom and is unable to travel to India owing"
+ " to restrictions imposed due to the COVID-19 pandemic. According to the CBI, in the present case,"
+ " Kapoor had obtained a gratification or pecuniary advantage of ā¹ 307 crore, and thereby caused"
+ " Yes Bank a loss of ā¹ 1,800 crore by extending credit facilities to Avantha Group, when it was"
+ " not eligible for the same"
+ ),
"question": "Is this person invovled in fraud?",
}
)
diff --git a/tests/pipelines/test_pipelines_summarization.py b/tests/pipelines/test_pipelines_summarization.py
index e434ed742dc7..d797383811c6 100644
--- a/tests/pipelines/test_pipelines_summarization.py
+++ b/tests/pipelines/test_pipelines_summarization.py
@@ -18,6 +18,7 @@
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
LEDConfig,
+ LongT5Config,
SummarizationPipeline,
T5Config,
pipeline,
@@ -54,8 +55,8 @@ def run_pipeline_test(self, summarizer, _):
)
self.assertEqual(outputs, [{"summary_text": ANY(str)}])
- if not isinstance(model.config, (T5Config, LEDConfig)):
- # LED, T5 can handle it.
+ if not isinstance(model.config, (T5Config, LongT5Config, LEDConfig)):
+ # LED, T5, LongT5 can handle it.
# Too long.
with self.assertRaises(Exception):
outputs = summarizer("This " * 1000)
@@ -91,7 +92,49 @@ def test_small_model_tf(self):
@slow
def test_integration_torch_summarization(self):
summarizer = pipeline(task="summarization", device=DEFAULT_DEVICE_NUM)
- cnn_article = ' (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
- expected_cnn_summary = " The Palestinian Authority becomes the 123rd member of the International Criminal Court . The move gives the court jurisdiction over alleged crimes in Palestinian territories . Israel and the United States opposed the Palestinians' efforts to join the court . Rights group Human Rights Watch welcomes the move, says governments seeking to penalize Palestine should end pressure ."
+ cnn_article = (
+ " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
+ " and Faith Karimi contributed to this report."
+ )
+ expected_cnn_summary = (
+ " The Palestinian Authority becomes the 123rd member of the International Criminal Court . The move gives"
+ " the court jurisdiction over alleged crimes in Palestinian territories . Israel and the United States"
+ " opposed the Palestinians' efforts to join the court . Rights group Human Rights Watch welcomes the move,"
+ " says governments seeking to penalize Palestine should end pressure ."
+ )
result = summarizer(cnn_article)
self.assertEqual(result[0]["summary_text"], expected_cnn_summary)
diff --git a/tests/pipelines/test_pipelines_table_question_answering.py b/tests/pipelines/test_pipelines_table_question_answering.py
index 86bbf991b039..ba7fdaa75c50 100644
--- a/tests/pipelines/test_pipelines_table_question_answering.py
+++ b/tests/pipelines/test_pipelines_table_question_answering.py
@@ -92,7 +92,8 @@ def test_small_model_tf(self):
},
query=[
"What repository has the largest number of stars?",
- "Given that the numbers of stars defines if a repository is active, what repository is the most active?",
+ "Given that the numbers of stars defines if a repository is active, what repository is the most"
+ " active?",
"What is the number of repositories?",
"What is the average number of stars?",
"What is the total amount of stars?",
@@ -194,7 +195,8 @@ def test_small_model_pt(self):
},
query=[
"What repository has the largest number of stars?",
- "Given that the numbers of stars defines if a repository is active, what repository is the most active?",
+ "Given that the numbers of stars defines if a repository is active, what repository is the most"
+ " active?",
"What is the number of repositories?",
"What is the average number of stars?",
"What is the total amount of stars?",
@@ -313,7 +315,8 @@ def test_slow_tokenizer_sqa_pt(self):
},
query=[
"What repository has the largest number of stars?",
- "Given that the numbers of stars defines if a repository is active, what repository is the most active?",
+ "Given that the numbers of stars defines if a repository is active, what repository is the most"
+ " active?",
"What is the number of repositories?",
"What is the average number of stars?",
"What is the total amount of stars?",
@@ -434,7 +437,8 @@ def test_slow_tokenizer_sqa_tf(self):
},
query=[
"What repository has the largest number of stars?",
- "Given that the numbers of stars defines if a repository is active, what repository is the most active?",
+ "Given that the numbers of stars defines if a repository is active, what repository is the most"
+ " active?",
"What is the number of repositories?",
"What is the average number of stars?",
"What is the total amount of stars?",
diff --git a/tests/pipelines/test_pipelines_text_classification.py b/tests/pipelines/test_pipelines_text_classification.py
index 39deed9bee55..6bbc84989a21 100644
--- a/tests/pipelines/test_pipelines_text_classification.py
+++ b/tests/pipelines/test_pipelines_text_classification.py
@@ -39,6 +39,64 @@ def test_small_model_pt(self):
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
+ outputs = text_classifier("This is great !", top_k=2)
+ self.assertEqual(
+ nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}]
+ )
+
+ outputs = text_classifier(["This is great !", "This is bad"], top_k=2)
+ self.assertEqual(
+ nested_simplify(outputs),
+ [
+ [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
+ [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
+ ],
+ )
+
+ outputs = text_classifier("This is great !", top_k=1)
+ self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
+
+ # Legacy behavior
+ outputs = text_classifier("This is great !", return_all_scores=False)
+ self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
+
+ outputs = text_classifier("This is great !", return_all_scores=True)
+ self.assertEqual(
+ nested_simplify(outputs), [[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}]]
+ )
+
+ outputs = text_classifier(["This is great !", "Something else"], return_all_scores=True)
+ self.assertEqual(
+ nested_simplify(outputs),
+ [
+ [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
+ [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
+ ],
+ )
+
+ outputs = text_classifier(["This is great !", "Something else"], return_all_scores=False)
+ self.assertEqual(
+ nested_simplify(outputs),
+ [
+ {"label": "LABEL_0", "score": 0.504},
+ {"label": "LABEL_0", "score": 0.504},
+ ],
+ )
+
+ @require_torch
+ def test_accepts_torch_device(self):
+ import torch
+
+ text_classifier = pipeline(
+ task="text-classification",
+ model="hf-internal-testing/tiny-random-distilbert",
+ framework="pt",
+ device=torch.device("cpu"),
+ )
+
+ outputs = text_classifier("This is great !")
+ self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
+
@require_tf
def test_small_model_tf(self):
text_classifier = pipeline(
@@ -93,3 +151,37 @@ def run_pipeline_test(self, text_classifier, _):
)
self.assertTrue(outputs[0]["label"] in model.config.id2label.values())
self.assertTrue(outputs[1]["label"] in model.config.id2label.values())
+
+ # Forcing to get all results with `top_k=None`
+ # This is NOT the legacy format
+ outputs = text_classifier(valid_inputs, top_k=None)
+ N = len(model.config.id2label.values())
+ self.assertEqual(
+ nested_simplify(outputs),
+ [[{"label": ANY(str), "score": ANY(float)}] * N, [{"label": ANY(str), "score": ANY(float)}] * N],
+ )
+
+ valid_inputs = {"text": "HuggingFace is in ", "text_pair": "Paris is in France"}
+ outputs = text_classifier(valid_inputs)
+ self.assertEqual(
+ nested_simplify(outputs),
+ {"label": ANY(str), "score": ANY(float)},
+ )
+ self.assertTrue(outputs["label"] in model.config.id2label.values())
+
+ # This might be used as a text pair, but the tokenizer + pipe interaction
+ # makes it hard to tell that it's not using the pair properly
+ # https://github.com/huggingface/transformers/issues/17305
+ # We disabled this usage instead as it was producing wrong outputs.
+ invalid_input = [["HuggingFace is in ", "Paris is in France"]]
+ with self.assertRaises(ValueError):
+ text_classifier(invalid_input)
+
+ # This used to be valid for doing text pairs
+ # We're keeping it working for backward compatibility
+ outputs = text_classifier([[["HuggingFace is in ", "Paris is in France"]]])
+ self.assertEqual(
+ nested_simplify(outputs),
+ [{"label": ANY(str), "score": ANY(float)}],
+ )
+ self.assertTrue(outputs[0]["label"] in model.config.id2label.values())
diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py
index ca67c3bea13d..a26ed56d4cd4 100644
--- a/tests/pipelines/test_pipelines_text_generation.py
+++ b/tests/pipelines/test_pipelines_text_generation.py
@@ -15,7 +15,13 @@
import unittest
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
-from transformers.testing_utils import is_pipeline_test, require_tf, require_torch
+from transformers.testing_utils import (
+ is_pipeline_test,
+ require_accelerate,
+ require_tf,
+ require_torch,
+ require_torch_gpu,
+)
from .test_pipelines_common import ANY, PipelineTestCaseMeta
@@ -34,7 +40,10 @@ def test_small_model_pt(self):
outputs,
[
{
- "generated_text": "This is a test ā ā segmental segmental segmental 议议eski eski flutter flutter Lacy oscope. oscope. FiliFili@@"
+ "generated_text": (
+ "This is a test ā ā segmental segmental segmental 议议eski eski flutter flutter Lacy oscope."
+ " oscope. FiliFili@@"
+ )
}
],
)
@@ -45,12 +54,18 @@ def test_small_model_pt(self):
[
[
{
- "generated_text": "This is a test ā ā segmental segmental segmental 议议eski eski flutter flutter Lacy oscope. oscope. FiliFili@@"
+ "generated_text": (
+ "This is a test ā ā segmental segmental segmental 议议eski eski flutter flutter Lacy oscope."
+ " oscope. FiliFili@@"
+ )
}
],
[
{
- "generated_text": "This is a second test ā segmental segmental segmental 议议eski eski flutter flutter Lacy oscope. oscope. FiliFili@@"
+ "generated_text": (
+ "This is a second test ā segmental segmental segmental 议议eski eski flutter flutter Lacy"
+ " oscope. oscope. FiliFili@@"
+ )
}
],
],
@@ -97,7 +112,10 @@ def test_small_model_tf(self):
outputs,
[
{
- "generated_text": "This is a test FeyFeyFey(Croatis.), s.), Cannes Cannes Cannes é²é²Cannes Cannes Cannes ęµ please,"
+ "generated_text": (
+ "This is a test FeyFeyFey(Croatis.), s.), Cannes Cannes Cannes é²é²Cannes Cannes Cannes ęµ"
+ " please,"
+ )
}
],
)
@@ -108,12 +126,18 @@ def test_small_model_tf(self):
[
[
{
- "generated_text": "This is a test FeyFeyFey(Croatis.), s.), Cannes Cannes Cannes é²é²Cannes Cannes Cannes ęµ please,"
+ "generated_text": (
+ "This is a test FeyFeyFey(Croatis.), s.), Cannes Cannes Cannes é²é²Cannes Cannes Cannes ęµ"
+ " please,"
+ )
}
],
[
{
- "generated_text": "This is a second test Chieftain Chieftain prefecture prefecture prefecture Cannes Cannes Cannes é²é²Cannes Cannes Cannes ęµ please,"
+ "generated_text": (
+ "This is a second test Chieftain Chieftain prefecture prefecture prefecture Cannes Cannes"
+ " Cannes é²é²Cannes Cannes Cannes ęµ please,"
+ )
}
],
],
@@ -197,3 +221,63 @@ def run_pipeline_test(self, text_generator, _):
handle_long_generation="hole",
max_new_tokens=tokenizer.model_max_length + 10,
)
+
+ @require_torch
+ @require_accelerate
+ @require_torch_gpu
+ def test_small_model_pt_bloom_accelerate(self):
+ import torch
+
+ # Classic `model_kwargs`
+ pipe = pipeline(
+ model="hf-internal-testing/tiny-random-bloom",
+ model_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16},
+ )
+ self.assertEqual(pipe.model.device, torch.device(0))
+ self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
+ out = pipe("This is a test")
+ self.assertEqual(
+ out,
+ [
+ {
+ "generated_text": (
+ "This is a test test test test test test test test test test test test test test test test"
+ " test"
+ )
+ }
+ ],
+ )
+
+ # Upgraded those two to real pipeline arguments (they just get sent to the model as they're unlikely to mean anything else).
+ pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.bfloat16)
+ self.assertEqual(pipe.model.device, torch.device(0))
+ self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
+ out = pipe("This is a test")
+ self.assertEqual(
+ out,
+ [
+ {
+ "generated_text": (
+ "This is a test test test test test test test test test test test test test test test test"
+ " test"
+ )
+ }
+ ],
+ )
+
+ # torch_dtype not necessary
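+ # (presumably the tiny-random-bloom config already stores torch_dtype=bfloat16,
+ # so device_map="auto" alone is enough to restore bf16 weights)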
+ pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto")
+ self.assertEqual(pipe.model.device, torch.device(0))
+ self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
+ out = pipe("This is a test")
+ self.assertEqual(
+ out,
+ [
+ {
+ "generated_text": (
+ "This is a test test test test test test test test test test test test test test test test"
+ " test"
+ )
+ }
+ ],
+ )
diff --git a/tests/pipelines/test_pipelines_token_classification.py b/tests/pipelines/test_pipelines_token_classification.py
index 26cfa0d3be34..bc4eaef06255 100644
--- a/tests/pipelines/test_pipelines_token_classification.py
+++ b/tests/pipelines/test_pipelines_token_classification.py
@@ -278,15 +278,15 @@ def test_dbmdz_english(self):
NER_MODEL = "dbmdz/bert-large-cased-finetuned-conll03-english"
model = AutoModelForTokenClassification.from_pretrained(NER_MODEL)
tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_fast=True)
- sentence = """Enzo works at the the UN"""
+ sentence = """Enzo works at the UN"""
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer)
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output),
[
- {"entity": "I-PER", "score": 0.997, "word": "En", "start": 0, "end": 2, "index": 1},
- {"entity": "I-PER", "score": 0.996, "word": "##zo", "start": 2, "end": 4, "index": 2},
- {"entity": "I-ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24, "index": 7},
+ {"entity": "I-PER", "score": 0.998, "word": "En", "start": 0, "end": 2, "index": 1},
+ {"entity": "I-PER", "score": 0.997, "word": "##zo", "start": 2, "end": 4, "index": 2},
+ {"entity": "I-ORG", "score": 0.999, "word": "UN", "start": 18, "end": 20, "index": 6},
],
)
@@ -295,8 +295,8 @@ def test_dbmdz_english(self):
self.assertEqual(
nested_simplify(output),
[
- {"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
- {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
+ {"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
+ {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 18, "end": 20},
],
)
@@ -305,8 +305,8 @@ def test_dbmdz_english(self):
self.assertEqual(
nested_simplify(output[:3]),
[
- {"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
- {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
+ {"entity_group": "PER", "score": 0.998, "word": "Enzo", "start": 0, "end": 4},
+ {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 18, "end": 20},
],
)
@@ -315,8 +315,8 @@ def test_dbmdz_english(self):
self.assertEqual(
nested_simplify(output[:3]),
[
- {"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
- {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
+ {"entity_group": "PER", "score": 0.998, "word": "Enzo", "start": 0, "end": 4},
+ {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 18, "end": 20},
],
)
@@ -325,8 +325,8 @@ def test_dbmdz_english(self):
self.assertEqual(
nested_simplify(output),
[
- {"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
- {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
+ {"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
+ {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 18, "end": 20},
],
)
@@ -535,6 +535,20 @@ def test_aggregation_strategy_example2(self):
[{"entity_group": "PER", "score": 0.35, "word": "Ramazotti", "start": 0, "end": 13}],
)
+ @require_torch
+ @slow
+ def test_aggregation_strategy_offsets_with_leading_space(self):
+ sentence = "We're from New York"
+ model_name = "brandon25/deberta-base-finetuned-ner"
+ ner = pipeline("ner", model=model_name, ignore_labels=[], aggregation_strategy="max")
+ self.assertEqual(
+ nested_simplify(ner(sentence)),
+ [
+ {"entity_group": "O", "score": 1.0, "word": " We're from", "start": 0, "end": 10},
+ {"entity_group": "LOC", "score": 1.0, "word": " New York", "start": 10, "end": 19},
+ ],
+ )
+
@require_torch
def test_gather_pre_entities(self):
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
@@ -580,6 +594,41 @@ def test_gather_pre_entities(self):
],
)
+ @require_torch
+ def test_word_heuristic_leading_space(self):
+ model_name = "hf-internal-testing/tiny-random-deberta-v2"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+ token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
+
+ sentence = "I play the theremin"
+
+ tokens = tokenizer(
+ sentence,
+ return_attention_mask=False,
+ return_tensors="pt",
+ return_special_tokens_mask=True,
+ return_offsets_mapping=True,
+ )
+ offset_mapping = tokens.pop("offset_mapping").cpu().numpy()[0]
+ special_tokens_mask = tokens.pop("special_tokens_mask").cpu().numpy()[0]
+ input_ids = tokens["input_ids"].numpy()[0]
+ scores = np.array([[1, 0] for _ in input_ids]) # values irrelevant for heuristic
+
+ pre_entities = token_classifier.gather_pre_entities(
+ sentence,
+ input_ids,
+ scores,
+ offset_mapping,
+ special_tokens_mask,
+ aggregation_strategy=AggregationStrategy.FIRST,
+ )
+
+ # ensure expected tokenization and correct is_subword values
+ self.assertEqual(
+ [(entity["word"], entity["is_subword"]) for entity in pre_entities],
+ [("āI", False), ("āplay", False), ("āthe", False), ("āthere", False), ("min", True)],
+ )
+
@require_tf
def test_tf_only(self):
model_name = "hf-internal-testing/tiny-random-bert-tf-only" # This model only has a TensorFlow version
diff --git a/tests/pipelines/test_pipelines_translation.py b/tests/pipelines/test_pipelines_translation.py
index 368f6bc9c5cc..3c5999f36e60 100644
--- a/tests/pipelines/test_pipelines_translation.py
+++ b/tests/pipelines/test_pipelines_translation.py
@@ -61,7 +61,10 @@ def test_small_model_pt(self):
outputs,
[
{
- "translation_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide"
+ "translation_text": (
+ "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide"
+ " Beide Beide"
+ )
}
],
)
@@ -74,7 +77,10 @@ def test_small_model_tf(self):
outputs,
[
{
- "translation_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide"
+ "translation_text": (
+ "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide"
+ " Beide Beide"
+ )
}
],
)
@@ -87,7 +93,10 @@ def test_en_to_de_pt(self):
outputs,
[
{
- "translation_text": "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine urine urine urine urine urine urine urine"
+ "translation_text": (
+ "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine"
+ " urine urine urine urine urine urine urine"
+ )
}
],
)
@@ -100,7 +109,10 @@ def test_en_to_de_tf(self):
outputs,
[
{
- "translation_text": "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine urine urine urine urine urine urine urine"
+ "translation_text": (
+ "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine"
+ " urine urine urine urine urine urine urine"
+ )
}
],
)
diff --git a/tests/pipelines/test_pipelines_visual_question_answering.py b/tests/pipelines/test_pipelines_visual_question_answering.py
new file mode 100644
index 000000000000..d3315681f47e
--- /dev/null
+++ b/tests/pipelines/test_pipelines_visual_question_answering.py
@@ -0,0 +1,115 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, is_vision_available
+from transformers.pipelines import pipeline
+from transformers.testing_utils import (
+ is_pipeline_test,
+ nested_simplify,
+ require_tf,
+ require_torch,
+ require_vision,
+ slow,
+)
+
+from .test_pipelines_common import ANY, PipelineTestCaseMeta
+
+
+if is_vision_available():
+ from PIL import Image
+else:
+
+ class Image:
+ @staticmethod
+ def open(*args, **kwargs):
+ pass
+
+
+@is_pipeline_test
+@require_torch
+@require_vision
+class VisualQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
+ model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
+
+ def get_test_pipeline(self, model, tokenizer, feature_extractor):
+ vqa_pipeline = pipeline("visual-question-answering", model="hf-internal-testing/tiny-vilt-random-vqa")
+ examples = [
+ {
+ "image": Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
+ "question": "How many cats are there?",
+ },
+ {
+ "image": "./tests/fixtures/tests_samples/COCO/000000039769.png",
+ "question": "How many cats are there?",
+ },
+ ]
+ return vqa_pipeline, examples
+
+ def run_pipeline_test(self, vqa_pipeline, examples):
+ outputs = vqa_pipeline(examples, top_k=1)
+ self.assertEqual(
+ outputs,
+ [
+ [{"score": ANY(float), "answer": ANY(str)}],
+ [{"score": ANY(float), "answer": ANY(str)}],
+ ],
+ )
+
+ @require_torch
+ def test_small_model_pt(self):
+ vqa_pipeline = pipeline("visual-question-answering", model="hf-internal-testing/tiny-vilt-random-vqa")
+ image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
+ question = "How many cats are there?"
+
+ outputs = vqa_pipeline(image=image, question="How many cats are there?", top_k=2)
+ self.assertEqual(
+ outputs, [{"score": ANY(float), "answer": ANY(str)}, {"score": ANY(float), "answer": ANY(str)}]
+ )
+
+ outputs = vqa_pipeline({"image": image, "question": question}, top_k=2)
+ self.assertEqual(
+ outputs, [{"score": ANY(float), "answer": ANY(str)}, {"score": ANY(float), "answer": ANY(str)}]
+ )
+
+ @slow
+ @require_torch
+ def test_large_model_pt(self):
+ vqa_pipeline = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
+ image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
+ question = "How many cats are there?"
+
+ outputs = vqa_pipeline(image=image, question=question, top_k=2)
+ self.assertEqual(
+ nested_simplify(outputs, decimals=4), [{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]
+ )
+
+ outputs = vqa_pipeline({"image": image, "question": question}, top_k=2)
+ self.assertEqual(
+ nested_simplify(outputs, decimals=4), [{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]
+ )
+
+ outputs = vqa_pipeline(
+ [{"image": image, "question": question}, {"image": image, "question": question}], top_k=2
+ )
+ self.assertEqual(
+ nested_simplify(outputs, decimals=4),
+ [[{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]] * 2,
+ )
+
+ @require_tf
+ @unittest.skip("Visual question answering not implemented in TF")
+ def test_small_model_tf(self):
+ pass
diff --git a/tests/pipelines/test_pipelines_zero_shot.py b/tests/pipelines/test_pipelines_zero_shot.py
index ed564581e526..af98ac020172 100644
--- a/tests/pipelines/test_pipelines_zero_shot.py
+++ b/tests/pipelines/test_pipelines_zero_shot.py
@@ -202,14 +202,39 @@ def test_large_model_pt(self):
},
)
outputs = zero_shot_classifier(
- "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
+ "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks"
+ " in an encoder-decoder configuration. The best performing models also connect the encoder and decoder"
+ " through an attention mechanism. We propose a new simple network architecture, the Transformer, based"
+ " solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two"
+ " machine translation tasks show these models to be superior in quality while being more parallelizable"
+ " and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014"
+ " English-to-German translation task, improving over the existing best results, including ensembles by"
+ " over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new"
+ " single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small"
+ " fraction of the training costs of the best models from the literature. We show that the Transformer"
+ " generalizes well to other tasks by applying it successfully to English constituency parsing both with"
+ " large and limited training data.",
candidate_labels=["machine learning", "statistics", "translation", "vision"],
multi_label=True,
)
self.assertEqual(
nested_simplify(outputs),
{
- "sequence": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
+ "sequence": (
+ "The dominant sequence transduction models are based on complex recurrent or convolutional neural"
+ " networks in an encoder-decoder configuration. The best performing models also connect the"
+ " encoder and decoder through an attention mechanism. We propose a new simple network"
+ " architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence"
+ " and convolutions entirely. Experiments on two machine translation tasks show these models to be"
+ " superior in quality while being more parallelizable and requiring significantly less time to"
+ " train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task,"
+ " improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014"
+ " English-to-French translation task, our model establishes a new single-model state-of-the-art"
+ " BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training"
+ " costs of the best models from the literature. We show that the Transformer generalizes well to"
+ " other tasks by applying it successfully to English constituency parsing both with large and"
+ " limited training data."
+ ),
"labels": ["translation", "machine learning", "vision", "statistics"],
"scores": [0.817, 0.713, 0.018, 0.018],
},
@@ -232,14 +257,39 @@ def test_large_model_tf(self):
},
)
outputs = zero_shot_classifier(
- "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
+ "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks"
+ " in an encoder-decoder configuration. The best performing models also connect the encoder and decoder"
+ " through an attention mechanism. We propose a new simple network architecture, the Transformer, based"
+ " solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two"
+ " machine translation tasks show these models to be superior in quality while being more parallelizable"
+ " and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014"
+ " English-to-German translation task, improving over the existing best results, including ensembles by"
+ " over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new"
+ " single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small"
+ " fraction of the training costs of the best models from the literature. We show that the Transformer"
+ " generalizes well to other tasks by applying it successfully to English constituency parsing both with"
+ " large and limited training data.",
candidate_labels=["machine learning", "statistics", "translation", "vision"],
multi_label=True,
)
self.assertEqual(
nested_simplify(outputs),
{
- "sequence": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
+ "sequence": (
+ "The dominant sequence transduction models are based on complex recurrent or convolutional neural"
+ " networks in an encoder-decoder configuration. The best performing models also connect the"
+ " encoder and decoder through an attention mechanism. We propose a new simple network"
+ " architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence"
+ " and convolutions entirely. Experiments on two machine translation tasks show these models to be"
+ " superior in quality while being more parallelizable and requiring significantly less time to"
+ " train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task,"
+ " improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014"
+ " English-to-French translation task, our model establishes a new single-model state-of-the-art"
+ " BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training"
+ " costs of the best models from the literature. We show that the Transformer generalizes well to"
+ " other tasks by applying it successfully to English constituency parsing both with large and"
+ " limited training data."
+ ),
"labels": ["translation", "machine learning", "vision", "statistics"],
"scores": [0.817, 0.713, 0.018, 0.018],
},
diff --git a/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py b/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py
index 6bec48fda7ad..01185fdabac5 100644
--- a/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py
+++ b/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py
@@ -81,8 +81,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -91,29 +93,37 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=True,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ )
},
)
max_test_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of test examples to this "
+ "value if set."
+ )
},
)
train_file: Optional[str] = field(
@@ -170,8 +180,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py
index 853a19c3ec84..397346c7deec 100644
--- a/tests/test_configuration_common.py
+++ b/tests/test_configuration_common.py
@@ -23,11 +23,11 @@
import unittest.mock as mock
from pathlib import Path
-from huggingface_hub import Repository, delete_repo, login
+from huggingface_hub import HfFolder, delete_repo, set_access_token
from requests.exceptions import HTTPError
from transformers import AutoConfig, BertConfig, GPT2Config, is_torch_available
from transformers.configuration_utils import PretrainedConfig
-from transformers.testing_utils import PASS, USER, is_staging_test
+from transformers.testing_utils import TOKEN, USER, is_staging_test
sys.path.append(str(Path(__file__).parent.parent / "utils"))
@@ -42,6 +42,7 @@
"torchscript": True,
"torch_dtype": "float16",
"use_bfloat16": True,
+ "tf_legacy_loss": True,
"pruned_heads": {"a": 1},
"tie_word_embeddings": False,
"is_decoder": True,
@@ -156,6 +157,17 @@ def create_and_test_config_from_and_save_pretrained(self):
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
+ def create_and_test_config_from_and_save_pretrained_subfolder(self):
+ config_first = self.config_class(**self.inputs_dict)
+
+ subfolder = "test"
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ sub_tmpdirname = os.path.join(tmpdirname, subfolder)
+ config_first.save_pretrained(sub_tmpdirname)
+ config_second = self.config_class.from_pretrained(tmpdirname, subfolder=subfolder)
+
+ self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
+
def create_and_test_config_with_num_labels(self):
config = self.config_class(**self.inputs_dict, num_labels=5)
self.parent.assertEqual(len(config.id2label), 5)
@@ -196,6 +208,7 @@ def run_common_tests(self):
self.create_and_test_config_to_json_string()
self.create_and_test_config_to_json_file()
self.create_and_test_config_from_and_save_pretrained()
+ self.create_and_test_config_from_and_save_pretrained_subfolder()
self.create_and_test_config_with_num_labels()
self.check_config_can_be_init_without_params()
self.check_config_arguments_init()
@@ -205,22 +218,24 @@ def run_common_tests(self):
class ConfigPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls._token = login(username=USER, password=PASS)
+ cls._token = TOKEN
+ set_access_token(TOKEN)
+ HfFolder.save_token(TOKEN)
@classmethod
def tearDownClass(cls):
try:
- delete_repo(token=cls._token, name="test-config")
+ delete_repo(token=cls._token, repo_id="test-config")
except HTTPError:
pass
try:
- delete_repo(token=cls._token, name="test-config-org", organization="valid_org")
+ delete_repo(token=cls._token, repo_id="valid_org/test-config-org")
except HTTPError:
pass
try:
- delete_repo(token=cls._token, name="test-dynamic-config")
+ delete_repo(token=cls._token, repo_id="test-dynamic-config")
except HTTPError:
pass
@@ -228,46 +243,58 @@ def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
+ config.push_to_hub("test-config", use_auth_token=self._token)
+
+ new_config = BertConfig.from_pretrained(f"{USER}/test-config")
+ for k, v in config.__dict__.items():
+ if k != "transformers_version":
+ self.assertEqual(v, getattr(new_config, k))
+
+ # Reset repo
+ delete_repo(token=self._token, repo_id="test-config")
+
+ # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
- config.save_pretrained(os.path.join(tmp_dir, "test-config"), push_to_hub=True, use_auth_token=self._token)
+ config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, use_auth_token=self._token)
- new_config = BertConfig.from_pretrained(f"{USER}/test-config")
- for k, v in config.__dict__.items():
- if k != "transformers_version":
- self.assertEqual(v, getattr(new_config, k))
+ new_config = BertConfig.from_pretrained(f"{USER}/test-config")
+ for k, v in config.__dict__.items():
+ if k != "transformers_version":
+ self.assertEqual(v, getattr(new_config, k))
def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
+ config.push_to_hub("valid_org/test-config-org", use_auth_token=self._token)
+
+ new_config = BertConfig.from_pretrained("valid_org/test-config-org")
+ for k, v in config.__dict__.items():
+ if k != "transformers_version":
+ self.assertEqual(v, getattr(new_config, k))
+ # Reset repo
+ delete_repo(token=self._token, repo_id="valid_org/test-config-org")
+
+ # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(
- os.path.join(tmp_dir, "test-config-org"),
- push_to_hub=True,
- use_auth_token=self._token,
- organization="valid_org",
+ tmp_dir, repo_id="valid_org/test-config-org", push_to_hub=True, use_auth_token=self._token
)
- new_config = BertConfig.from_pretrained("valid_org/test-config-org")
- for k, v in config.__dict__.items():
- if k != "transformers_version":
- self.assertEqual(v, getattr(new_config, k))
+ new_config = BertConfig.from_pretrained("valid_org/test-config-org")
+ for k, v in config.__dict__.items():
+ if k != "transformers_version":
+ self.assertEqual(v, getattr(new_config, k))
def test_push_to_hub_dynamic_config(self):
CustomConfig.register_for_auto_class()
config = CustomConfig(attribute=42)
- with tempfile.TemporaryDirectory() as tmp_dir:
- repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-config", use_auth_token=self._token)
- config.save_pretrained(tmp_dir)
-
- # This has added the proper auto_map field to the config
- self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
- # The code has been copied from fixtures
- self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_configuration.py")))
+ config.push_to_hub("test-dynamic-config", use_auth_token=self._token)
- repo.push_to_hub()
+ # This has added the proper auto_map field to the config
+ self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
@@ -300,22 +327,32 @@ def test_config_common_kwargs_is_complete(self):
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
if len(keys_with_defaults) > 0:
raise ValueError(
- "The following keys are set with the default values in `test_configuration_common.config_common_kwargs` "
- f"pick another value for them: {', '.join(keys_with_defaults)}."
+ "The following keys are set with the default values in"
+ " `test_configuration_common.config_common_kwargs` pick another value for them:"
+ f" {', '.join(keys_with_defaults)}."
)
+ def test_from_pretrained_subfolder(self):
+ with self.assertRaises(OSError):
+ # config is in subfolder, the following should not work without specifying the subfolder
+ _ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder")
+
+ config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder", subfolder="bert")
+
+ self.assertIsNotNone(config)
+
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
- response_mock.headers = []
+ response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
- with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
+ with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# Check that we did call the fake head request
mock_head.assert_called()
diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py
index e95db7eae860..3ecf89a90867 100644
--- a/tests/test_feature_extraction_common.py
+++ b/tests/test_feature_extraction_common.py
@@ -22,10 +22,10 @@
import unittest.mock as mock
from pathlib import Path
-from huggingface_hub import Repository, delete_repo, login
+from huggingface_hub import HfFolder, delete_repo, set_access_token
from requests.exceptions import HTTPError
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
-from transformers.testing_utils import PASS, USER, get_tests_dir, is_staging_test
+from transformers.testing_utils import TOKEN, USER, check_json_file_has_correct_format, get_tests_dir, is_staging_test
from transformers.utils import is_torch_available, is_vision_available
@@ -48,44 +48,91 @@
def prepare_image_inputs(feature_extract_tester, equal_resolution=False, numpify=False, torchify=False):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
+
+ One can specify whether the images are of the same resolution or not.
"""
assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time"
- if equal_resolution:
- image_inputs = []
- for i in range(feature_extract_tester.batch_size):
- image_inputs.append(
- np.random.randint(
- 255,
- size=(
- feature_extract_tester.num_channels,
- feature_extract_tester.max_resolution,
- feature_extract_tester.max_resolution,
- ),
- dtype=np.uint8,
- )
- )
- else:
- image_inputs = []
- for i in range(feature_extract_tester.batch_size):
- width, height = np.random.choice(
- np.arange(feature_extract_tester.min_resolution, feature_extract_tester.max_resolution), 2
- )
- image_inputs.append(
- np.random.randint(255, size=(feature_extract_tester.num_channels, width, height), dtype=np.uint8)
+ image_inputs = []
+ for i in range(feature_extract_tester.batch_size):
+ if equal_resolution:
+ width = height = feature_extract_tester.max_resolution
+ else:
+ # To avoid getting image width/height 0
+ min_resolution = feature_extract_tester.min_resolution
+ if getattr(feature_extract_tester, "size_divisor", None):
+ # If `size_divisor` is defined, the image needs to have width/height >= `size_divisor`
+ min_resolution = max(feature_extract_tester.size_divisor, min_resolution)
+ width, height = np.random.choice(np.arange(min_resolution, feature_extract_tester.max_resolution), 2)
+ image_inputs.append(
+ np.random.randint(
+ 255,
+ size=(
+ feature_extract_tester.num_channels,
+ width,
+ height,
+ ),
+ dtype=np.uint8,
)
+ )
if not numpify and not torchify:
# PIL expects the channel dimension as last dimension
- image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
+ image_inputs = [Image.fromarray(np.moveaxis(image, 0, -1)) for image in image_inputs]
if torchify:
- image_inputs = [torch.from_numpy(x) for x in image_inputs]
+ image_inputs = [torch.from_numpy(image) for image in image_inputs]
return image_inputs
+def prepare_video(feature_extract_tester, width=10, height=10, numpify=False, torchify=False):
+ """This function prepares a video as a list of PIL images/NumPy arrays/PyTorch tensors."""
+
+ video = []
+ for i in range(feature_extract_tester.num_frames):
+ video.append(np.random.randint(255, size=(feature_extract_tester.num_channels, width, height), dtype=np.uint8))
+
+ if not numpify and not torchify:
+ # PIL expects the channel dimension as last dimension
+ video = [Image.fromarray(np.moveaxis(frame, 0, -1)) for frame in video]
+
+ if torchify:
+ video = [torch.from_numpy(frame) for frame in video]
+
+ return video
+
+
+def prepare_video_inputs(feature_extract_tester, equal_resolution=False, numpify=False, torchify=False):
+ """This function prepares a batch of videos: a list of list of PIL images, or a list of list of numpy arrays if
+ one specifies numpify=True, or a list of list of PyTorch tensors if one specifies torchify=True.
+
+ One can specify whether the videos are of the same resolution or not.
+ """
+
+ assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time"
+
+ video_inputs = []
+ for i in range(feature_extract_tester.batch_size):
+ if equal_resolution:
+ width = height = feature_extract_tester.max_resolution
+ else:
+ width, height = np.random.choice(
+ np.arange(feature_extract_tester.min_resolution, feature_extract_tester.max_resolution), 2
+ )
+ video = prepare_video(
+ feature_extract_tester=feature_extract_tester,
+ width=width,
+ height=height,
+ numpify=numpify,
+ torchify=torchify,
+ )
+ video_inputs.append(video)
+
+ return video_inputs
+
+
class FeatureExtractionSavingTestMixin:
def test_feat_extract_to_json_string(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
@@ -107,7 +154,8 @@ def test_feat_extract_from_and_save_pretrained(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
- feat_extract_first.save_pretrained(tmpdirname)
+ saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
+ check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
self.assertEqual(feat_extract_second.to_dict(), feat_extract_first.to_dict())
@@ -122,13 +170,13 @@ def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
- response_mock.headers = []
+ response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# Under the mock environment we get a 500 error when trying to reach the model.
- with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
+ with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# Check that we did call the fake head request
mock_head.assert_called()
@@ -138,68 +186,80 @@ def test_cached_files_are_used_when_internet_is_down(self):
class FeatureExtractorPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls._token = login(username=USER, password=PASS)
+ cls._token = TOKEN
+ set_access_token(TOKEN)
+ HfFolder.save_token(TOKEN)
@classmethod
def tearDownClass(cls):
try:
- delete_repo(token=cls._token, name="test-feature-extractor")
+ delete_repo(token=cls._token, repo_id="test-feature-extractor")
except HTTPError:
pass
try:
- delete_repo(token=cls._token, name="test-feature-extractor-org", organization="valid_org")
+ delete_repo(token=cls._token, repo_id="valid_org/test-feature-extractor-org")
except HTTPError:
pass
try:
- delete_repo(token=cls._token, name="test-dynamic-feature-extractor")
+ delete_repo(token=cls._token, repo_id="test-dynamic-feature-extractor")
except HTTPError:
pass
def test_push_to_hub(self):
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
+ feature_extractor.push_to_hub("test-feature-extractor", use_auth_token=self._token)
+
+ new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
+ for k, v in feature_extractor.__dict__.items():
+ self.assertEqual(v, getattr(new_feature_extractor, k))
+
+ # Reset repo
+ delete_repo(token=self._token, repo_id="test-feature-extractor")
+
+ # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained(
- os.path.join(tmp_dir, "test-feature-extractor"), push_to_hub=True, use_auth_token=self._token
+ tmp_dir, repo_id="test-feature-extractor", push_to_hub=True, use_auth_token=self._token
)
- new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
- for k, v in feature_extractor.__dict__.items():
- self.assertEqual(v, getattr(new_feature_extractor, k))
+ new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
+ for k, v in feature_extractor.__dict__.items():
+ self.assertEqual(v, getattr(new_feature_extractor, k))
def test_push_to_hub_in_organization(self):
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
+ feature_extractor.push_to_hub("valid_org/test-feature-extractor", use_auth_token=self._token)
+
+ new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor")
+ for k, v in feature_extractor.__dict__.items():
+ self.assertEqual(v, getattr(new_feature_extractor, k))
+
+ # Reset repo
+ delete_repo(token=self._token, repo_id="valid_org/test-feature-extractor")
+ # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained(
- os.path.join(tmp_dir, "test-feature-extractor-org"),
- push_to_hub=True,
- use_auth_token=self._token,
- organization="valid_org",
+ tmp_dir, repo_id="valid_org/test-feature-extractor-org", push_to_hub=True, use_auth_token=self._token
)
- new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org")
- for k, v in feature_extractor.__dict__.items():
- self.assertEqual(v, getattr(new_feature_extractor, k))
+ new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org")
+ for k, v in feature_extractor.__dict__.items():
+ self.assertEqual(v, getattr(new_feature_extractor, k))
def test_push_to_hub_dynamic_feature_extractor(self):
CustomFeatureExtractor.register_for_auto_class()
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
- with tempfile.TemporaryDirectory() as tmp_dir:
- repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-feature-extractor", use_auth_token=self._token)
- feature_extractor.save_pretrained(tmp_dir)
-
- # This has added the proper auto_map field to the config
- self.assertDictEqual(
- feature_extractor.auto_map,
- {"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
- )
- # The code has been copied from fixtures
- self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_feature_extraction.py")))
+ feature_extractor.push_to_hub("test-dynamic-feature-extractor", use_auth_token=self._token)
- repo.push_to_hub()
+ # This has added the proper auto_map field to the config
+ self.assertDictEqual(
+ feature_extractor.auto_map,
+ {"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
+ )
new_feature_extractor = AutoFeatureExtractor.from_pretrained(
f"{USER}/test-dynamic-feature-extractor", trust_remote_code=True
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index ac45a1c10822..8f80d7fa42f7 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -19,6 +19,7 @@
import json
import os
import os.path
+import pickle
import random
import sys
import tempfile
@@ -31,7 +32,7 @@
import numpy as np
import transformers
-from huggingface_hub import Repository, delete_repo, login
+from huggingface_hub import HfFolder, delete_repo, set_access_token
from requests.exceptions import HTTPError
from transformers import (
AutoConfig,
@@ -43,14 +44,16 @@
)
from transformers.models.auto import get_values
from transformers.testing_utils import (
- PASS,
+ TOKEN,
USER,
CaptureLogger,
TestCasePlus,
is_pt_flax_cross_test,
is_pt_tf_cross_test,
is_staging_test,
+ require_accelerate,
require_torch,
+ require_torch_gpu,
require_torch_multi_gpu,
require_usr_bin_time,
slow,
@@ -59,6 +62,7 @@
from transformers.utils import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
+ is_accelerate_available,
is_flax_available,
is_tf_available,
is_torch_fx_available,
@@ -71,6 +75,10 @@
from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # noqa E402
+if is_accelerate_available():
+ from accelerate.utils import compute_module_sizes
+
+
if is_torch_available():
import torch
from torch import nn
@@ -91,8 +99,11 @@
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+ MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
AdaptiveEmbedding,
+ AutoModelForCausalLM,
+ AutoTokenizer,
BertConfig,
BertModel,
PreTrainedModel,
@@ -124,6 +135,7 @@ def _config_zero_init(config):
TINY_T5 = "patrickvonplaten/t5-tiny-random"
+TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
@require_torch
@@ -143,6 +155,7 @@ class ModelTesterMixin:
test_model_parallel = False
is_encoder_decoder = False
has_attentions = True
+ model_split_percents = [0.5, 0.7, 0.9]
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = copy.deepcopy(inputs_dict)
@@ -170,6 +183,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
*get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
+ *get_values(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING),
]:
inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
@@ -474,123 +488,119 @@ def test_training_gradient_checkpointing(self):
loss.backward()
def test_attention_outputs(self):
- if not self.has_attentions:
- pass
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
- else:
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ seq_len = getattr(self.model_tester, "seq_length", None)
+ decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
+ decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+ chunk_length = getattr(self.model_tester, "chunk_length", None)
+ if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
+ encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
- seq_len = getattr(self.model_tester, "seq_length", None)
- decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
- encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
- decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
- encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
- chunk_length = getattr(self.model_tester, "chunk_length", None)
- if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
- encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- config.return_dict = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- if chunk_length is not None:
- self.assertListEqual(
- list(attentions[0].shape[-4:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
- )
- else:
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
- )
- out_len = len(outputs)
-
- if self.is_encoder_decoder:
- correct_outlen = 5
-
- # loss is at first position
- if "labels" in inputs_dict:
- correct_outlen += 1 # loss is added to beginning
- # Question Answering model returns start_logits and end_logits
- if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
- correct_outlen += 1 # start_logits and end_logits instead of only 1 output
- if "past_key_values" in outputs:
- correct_outlen += 1 # past_key_values have been returned
-
- self.assertEqual(out_len, correct_outlen)
-
- # decoder attentions
- decoder_attentions = outputs.decoder_attentions
- self.assertIsInstance(decoder_attentions, (list, tuple))
- self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
- self.assertListEqual(
- list(decoder_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
- )
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
- # cross attentions
- cross_attentions = outputs.cross_attentions
- self.assertIsInstance(cross_attentions, (list, tuple))
- self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
- self.assertListEqual(
- list(cross_attentions[0].shape[-3:]),
- [
- self.model_tester.num_attention_heads,
- decoder_seq_length,
- encoder_key_length,
- ],
- )
+ if chunk_length is not None:
+ self.assertListEqual(
+ list(attentions[0].shape[-4:]),
+ [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
+ )
+ else:
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+ )
+ out_len = len(outputs)
+
+ if self.is_encoder_decoder:
+ correct_outlen = 5
+
+ # loss is at first position
+ if "labels" in inputs_dict:
+ correct_outlen += 1 # loss is added to beginning
+ # Question Answering model returns start_logits and end_logits
+ if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
+ correct_outlen += 1 # start_logits and end_logits instead of only 1 output
+ if "past_key_values" in outputs:
+ correct_outlen += 1 # past_key_values have been returned
+
+ self.assertEqual(out_len, correct_outlen)
+
+ # decoder attentions
+ decoder_attentions = outputs.decoder_attentions
+ self.assertIsInstance(decoder_attentions, (list, tuple))
+ self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(decoder_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
+ )
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ # cross attentions
+ cross_attentions = outputs.cross_attentions
+ self.assertIsInstance(cross_attentions, (list, tuple))
+ self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(cross_attentions[0].shape[-3:]),
+ [
+ self.model_tester.num_attention_heads,
+ decoder_seq_length,
+ encoder_key_length,
+ ],
+ )
- if hasattr(self.model_tester, "num_hidden_states_types"):
- added_hidden_states = self.model_tester.num_hidden_states_types
- elif self.is_encoder_decoder:
- added_hidden_states = 2
- else:
- added_hidden_states = 1
- self.assertEqual(out_len + added_hidden_states, len(outputs))
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ elif self.is_encoder_decoder:
+ added_hidden_states = 2
+ else:
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
- self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
- if chunk_length is not None:
- self.assertListEqual(
- list(self_attentions[0].shape[-4:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
- )
- else:
- self.assertListEqual(
- list(self_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
- )
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+ if chunk_length is not None:
+ self.assertListEqual(
+ list(self_attentions[0].shape[-4:]),
+ [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
+ )
+ else:
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+ )
@slow
def test_torchscript_simple(self):
@@ -640,6 +650,13 @@ def _create_and_check_torchscript(self, config, inputs_dict):
traced_model = torch.jit.trace(
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
)
+ elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
+ input_ids = inputs["input_ids"]
+ bbox = inputs["bbox"]
+ image = inputs["image"].tensor
+ traced_model = torch.jit.trace(
+ model, (input_ids, bbox, image), check_trace=False
+ ) # when traced model is checked, an error is produced due to name mangling
else:
main_input = inputs[main_input_name]
traced_model = torch.jit.trace(model, main_input)
@@ -728,18 +745,36 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSMT still requires this hack -> FSMT should probably be refactored similar to BART afterward
labels = inputs.get("labels", None)
- input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
+ input_names = [
+ "attention_mask",
+ "decoder_attention_mask",
+ "decoder_input_ids",
+ "input_features",
+ "input_ids",
+ "input_values",
+ ]
if labels is not None:
input_names.append("labels")
+
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
else:
- input_names = ["input_ids", "attention_mask", "token_type_ids"]
- input_ids = inputs["input_ids"]
+ input_names = [
+ "attention_mask",
+ "bbox",
+ "input_features",
+ "input_ids",
+ "input_values",
+ "pixel_values",
+ "token_type_ids",
+ "visual_feats",
+ "visual_pos",
+ ]
labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None)
@@ -752,21 +787,22 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
- input_names = filtered_inputs.keys()
+ input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
- rank = len(input_ids.shape)
- if rank not in [2, 3]:
- raise NotImplementedError(
- f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
- )
+ if (
+ isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values()))
+ and (not hasattr(model.config, "problem_type") or model.config.problem_type is None)
+ ):
+ model.config.problem_type = "single_label_classification"
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
- except RuntimeError:
- self.fail("Couldn't trace module.")
+ except Exception as e:
+ self.fail(f"Couldn't trace module: {e}")
def flatten_output(output):
flatten = []
@@ -789,6 +825,40 @@ def flatten_output(output):
f"traced {i}th output doesn't match model {i}th output for {model_class}",
)
+ # Test that the model can be TorchScripted
+ try:
+ scripted = torch.jit.script(traced_model)
+ except Exception as e:
+ self.fail(f"Could not TorchScript the traced model: {e}")
+ scripted_output = scripted(**filtered_inputs)
+ scripted_output = flatten_output(scripted_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], scripted_output[i]),
+ f"scripted {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
+ # Test that the model can be serialized and restored properly
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+ try:
+ with open(pkl_file_name, "wb") as f:
+ pickle.dump(traced_model, f)
+ with open(pkl_file_name, "rb") as f:
+ loaded = pickle.load(f)
+ except Exception as e:
+ self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+ loaded_output = loaded(**filtered_inputs)
+ loaded_output = flatten_output(loaded_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], loaded_output[i]),
+ f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
def test_headmasking(self):
if not self.test_head_masking:
return
@@ -1447,7 +1517,12 @@ def recursive_check(tuple_object, dict_object):
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
- msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+ f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
+ f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
+ ),
)
recursive_check(tuple_output, dict_output)
@@ -1550,7 +1625,7 @@ def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_cla
# Copied from tests.test_modeling_tf_common.TFModelTesterMixin.check_pt_tf_outputs
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
- """Check the outputs from PyTorch and TensorFlow models are closed enough. Checks are done in a recursive way.
+ """Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.
Args:
model_class: The class of the model that is currently being tested. For example, `TFBertModel`,
@@ -1576,13 +1651,13 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, nam
# TODO: remove this method and this line after issues are fixed
tf_outputs, pt_outputs = self._postprocessing_to_ignore_test_cases(tf_outputs, pt_outputs, model_class)
- tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
- pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+ tf_keys = [k for k, v in tf_outputs.items() if v is not None]
+ pt_keys = [k for k, v in pt_outputs.items() if v is not None]
self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")
# convert to the case of `tuple`
- # appending each key to the current (string) `names`
+ # appending each key to the current (string) `name`
attributes = tuple([f"{name}.{k}" for k in tf_keys])
self.check_pt_tf_outputs(
tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
@@ -1598,10 +1673,10 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, nam
self.assertEqual(
len(attributes),
len(tf_outputs),
- f"{name}: The tuple `names` should have the same length as `tf_outputs`",
+ f"{name}: The tuple `attributes` should have the same length as `tf_outputs`",
)
else:
- # case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `names`
+ # case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
@@ -1633,10 +1708,11 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, nam
tf_outputs[pt_nans] = 0
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
- self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
+ self.assertLessEqual(max_diff, tol, f"{name}: Difference between PyTorch and TF is {max_diff} (>= {tol}).")
else:
raise ValueError(
- f"`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
+ "`tf_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `tf.Tensor`. Got"
+ f" {type(tf_outputs)} instead."
)
def prepare_tf_inputs_from_pt_inputs(self, pt_inputs_dict):
@@ -1771,7 +1847,7 @@ def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
- def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
+ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
"""
Args:
model_class: The class of the model that is currently being tested. For example, ..., etc.
@@ -1781,24 +1857,71 @@ def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
- if type(fx_outputs) in [tuple, list]:
- self.assertEqual(type(fx_outputs), type(pt_outputs))
- self.assertEqual(len(fx_outputs), len(pt_outputs))
- if type(names) == tuple:
- for fo, po, name in zip(fx_outputs, pt_outputs, names):
- self.check_outputs(fo, po, model_class, names=name)
- elif type(names) == str:
- for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
- self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
+
+ self.assertEqual(type(name), str)
+ if attributes is not None:
+ self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
+
+ # Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
+ if isinstance(fx_outputs, ModelOutput):
+ self.assertTrue(
+ isinstance(pt_outputs, ModelOutput),
+ f"{name}: `pt_outputs` should an instance of `ModelOutput` when `fx_outputs` is",
+ )
+
+ fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+ pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+ self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch")
+
+ # convert to the case of `tuple`
+ # appending each key to the current (string) `name`
+ attributes = tuple([f"{name}.{k}" for k in fx_keys])
+ self.check_pt_flax_outputs(
+ fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
+ )
+
+ # Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
+ elif type(fx_outputs) in [tuple, list]:
+ self.assertEqual(
+ type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch"
+ )
+ self.assertEqual(
+ len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch"
+ )
+
+ if attributes is not None:
+ # case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
+ self.assertEqual(
+ len(attributes),
+ len(fx_outputs),
+ f"{name}: The tuple `attributes` should have the same length as `fx_outputs`",
+ )
else:
- raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
+ # case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
+ attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
+
+ for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
+ self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
+
elif isinstance(fx_outputs, jnp.ndarray):
- self.assertTrue(isinstance(pt_outputs, torch.Tensor))
+ self.assertTrue(
+ isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should a tensor when `fx_outputs` is"
+ )
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
fx_outputs = np.array(fx_outputs)
pt_outputs = pt_outputs.detach().to("cpu").numpy()
+ self.assertEqual(
+ fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch"
+ )
+
+ # deal with NumPy's scalars to make replacing nan values by 0 work.
+ if np.isscalar(fx_outputs):
+ fx_outputs = np.array([fx_outputs])
+ pt_outputs = np.array([pt_outputs])
+
fx_nans = np.isnan(fx_outputs)
pt_nans = np.isnan(pt_outputs)
@@ -1807,10 +1930,14 @@ def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
pt_outputs[pt_nans] = 0
fx_outputs[pt_nans] = 0
- self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5)
+ max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
+ self.assertLessEqual(
+ max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})."
+ )
else:
raise ValueError(
- f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
+ "`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got"
+ f" {type(fx_outputs)} instead."
)
@is_pt_flax_cross_test
@@ -1871,7 +1998,7 @@ def test_equivalence_pt_to_flax(self):
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
- self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
+ self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
@@ -1883,7 +2010,7 @@ def test_equivalence_pt_to_flax(self):
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
- self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
+ self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
@@ -1945,7 +2072,7 @@ def test_equivalence_flax_to_pt(self):
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
- self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
+ self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
@@ -1962,7 +2089,7 @@ def test_equivalence_flax_to_pt(self):
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
- self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
+ self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -2066,7 +2193,7 @@ def get_current_gpu_memory_use():
memory_after_parallelization = get_current_gpu_memory_use()
# Assert that the memory use on all devices is higher than it was when loaded only on CPU
- for n in range(torch.cuda.device_count()):
+ for n in range(len(model.device_map.keys())):
self.assertGreater(memory_after_parallelization[n], memory_at_start[n])
# Assert that the memory use of device 0 is lower than it was when the entire model was loaded on it
@@ -2142,6 +2269,115 @@ def cast_to_device(dictionary, device):
model.parallelize()
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
+ def check_device_map_is_respected(self, model, device_map):
+ for param_name, param in model.named_parameters():
+ # Find device in device_map
+ while len(param_name) > 0 and param_name not in device_map:
+ param_name = ".".join(param_name.split(".")[:-1])
+ if param_name not in device_map:
+ raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
+
+ param_device = device_map[param_name]
+ if param_device in ["cpu", "disk"]:
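+ # Weights offloaded to CPU or disk are replaced by placeholders on the "meta" device until they are needed.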
+ self.assertEqual(param.device, torch.device("meta"))
+ else:
+ self.assertEqual(param.device, torch.device(param_device))
+
+ @require_accelerate
+ @require_torch_gpu
+ def test_disk_offload(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ if model_class._no_split_modules is None:
+ continue
+
+ inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ model = model_class(config).eval()
+ model = model.to(torch_device)
+ base_output = model(**inputs_dict)
+
+ model_size = compute_module_sizes(model)[""]
+ max_size = int(self.model_split_percents[0] * model_size)
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.cpu().save_pretrained(tmp_dir)
+
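+ # Cap both GPU 0 and the CPU below the full model size so part of the checkpoint has to be offloaded to disk.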
+ max_memory = {0: max_size, "cpu": max_size}
+ with self.assertRaises(ValueError):
+ # This errors out because it's missing an offload folder
+ new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+
+ new_model = model_class.from_pretrained(
+ tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
+ )
+
+ self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+ new_output = new_model(**inputs_dict)
+
+ self.assertTrue(torch.allclose(base_output[0], new_output[0]))
+
+ @require_accelerate
+ @require_torch_gpu
+ def test_cpu_offload(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ if model_class._no_split_modules is None:
+ continue
+
+ inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ model = model_class(config).eval()
+ model = model.to(torch_device)
+ base_output = model(**inputs_dict)
+
+ model_size = compute_module_sizes(model)[""]
+ # We test several splits of sizes to make sure it works.
+ max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.cpu().save_pretrained(tmp_dir)
+
+ for max_size in max_gpu_sizes:
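+ # GPU 0 is capped at max_size while the CPU can hold the whole model, so the overflow ends up on the CPU.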
+ max_memory = {0: max_size, "cpu": model_size * 2}
+ new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+ # Making sure part of the model will actually end up offloaded
+ self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
+
+ self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+ new_output = new_model(**inputs_dict)
+
+ self.assertTrue(torch.allclose(base_output[0], new_output[0]))
+
+ @require_accelerate
+ @require_torch_multi_gpu
+ def test_model_parallelism(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ if model_class._no_split_modules is None:
+ continue
+
+ inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ model = model_class(config).eval()
+ model = model.to(torch_device)
+ base_output = model(**inputs_dict)
+
+ model_size = compute_module_sizes(model)[""]
+ # We test several splits of sizes to make sure it works.
+ max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.cpu().save_pretrained(tmp_dir)
+
+ for max_size in max_gpu_sizes:
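+ # GPU 0 is capped while GPU 1 could hold the full model, so the layers get split across both GPUs.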
+ max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
+ new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+ # Making sure the model is actually split across the two GPUs
+ self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
+
+ self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+ new_output = new_model(**inputs_dict)
+
+ self.assertTrue(torch.allclose(base_output[0], new_output[0]))
+
def test_problem_types(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -2276,6 +2512,15 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
return torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous()
+def check_models_equal(model1, model2):
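+ # Two models are considered equal only if every pair of corresponding parameters matches exactly.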
+ models_are_equal = True
+ for model1_p, model2_p in zip(model1.parameters(), model2.parameters()):
+ if model1_p.data.ne(model2_p.data).sum() > 0:
+ models_are_equal = False
+
+ return models_are_equal
+
+
@require_torch
class ModelUtilsTest(TestCasePlus):
@slow
@@ -2304,6 +2549,56 @@ def test_model_from_pretrained(self):
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(model.config, config)
+ def test_model_from_pretrained_subfolder(self):
+ config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
+ model = BertModel(config)
+
+ subfolder = "bert"
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.save_pretrained(os.path.join(tmp_dir, subfolder))
+
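+ # Loading from the root directory must fail since the weights only exist inside the subfolder.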
+ with self.assertRaises(OSError):
+ _ = BertModel.from_pretrained(tmp_dir)
+
+ model_loaded = BertModel.from_pretrained(tmp_dir, subfolder=subfolder)
+
+ self.assertTrue(check_models_equal(model, model_loaded))
+
+ def test_model_from_pretrained_subfolder_sharded(self):
+ config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
+ model = BertModel(config)
+
+ subfolder = "bert"
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.save_pretrained(os.path.join(tmp_dir, subfolder), max_shard_size="10KB")
+
+ with self.assertRaises(OSError):
+ _ = BertModel.from_pretrained(tmp_dir)
+
+ model_loaded = BertModel.from_pretrained(tmp_dir, subfolder=subfolder)
+
+ self.assertTrue(check_models_equal(model, model_loaded))
+
+ def test_model_from_pretrained_hub_subfolder(self):
+ subfolder = "bert"
+ model_id = "hf-internal-testing/tiny-random-bert-subfolder"
+ with self.assertRaises(OSError):
+ _ = BertModel.from_pretrained(model_id)
+
+ model = BertModel.from_pretrained(model_id, subfolder=subfolder)
+
+ self.assertIsNotNone(model)
+
+ def test_model_from_pretrained_hub_subfolder_sharded(self):
+ subfolder = "bert"
+ model_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
+ with self.assertRaises(OSError):
+ _ = BertModel.from_pretrained(model_id)
+
+ model = BertModel.from_pretrained(model_id, subfolder=subfolder)
+
+ self.assertIsNotNone(model)
+
def test_model_from_pretrained_with_different_pretrained_model_name(self):
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
self.assertIsNotNone(model)
@@ -2382,6 +2677,10 @@ def test_model_from_pretrained_torch_dtype(self):
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)
+ # test model whose first param is not of a floating type, but int
+ model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
+ self.assertEqual(model.dtype, torch.float32)
+
def test_no_super_init_config_and_model(self):
config = NoSuperInitConfig(attribute=32)
model = NoSuperInitModel(config)
@@ -2511,6 +2810,7 @@ def test_checkpoint_sharding_from_hub(self):
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
+ @require_accelerate
def test_from_pretrained_low_cpu_mem_usage_functional(self):
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
# sharded models
@@ -2523,6 +2823,7 @@ def test_from_pretrained_low_cpu_mem_usage_functional(self):
_ = BertModel.from_pretrained(mname, low_cpu_mem_usage=True)
@require_usr_bin_time
+ @require_accelerate
def test_from_pretrained_low_cpu_mem_usage_measured(self):
# test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
@@ -2561,18 +2862,77 @@ def test_from_pretrained_low_cpu_mem_usage_measured(self):
# functionality to load models directly on gpu, this test can be rewritten to use torch's
# cuda memory tracking and then we should be able to do a much more precise test.
+ @require_accelerate
+ @require_torch_multi_gpu
+ @slow
+ def test_model_parallelism_gpt2(self):
+ device_map = {"transformer.wte": 0, "transformer.wpe": 0, "lm_head": 0, "transformer.ln_f": 1}
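+ # Put the first six transformer blocks on GPU 0 and the remaining six on GPU 1.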
+ for i in range(12):
+ device_map[f"transformer.h.{i}"] = 0 if i <= 5 else 1
+
+ model = AutoModelForCausalLM.from_pretrained("gpt2", device_map=device_map)
+
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ inputs = tokenizer("Hello, my name is", return_tensors="pt")
+ output = model.generate(inputs["input_ids"].to(0))
+
+ text_output = tokenizer.decode(output[0].tolist())
+ self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
+
+ @require_accelerate
+ @require_torch_gpu
+ def test_from_pretrained_disk_offload_task_model(self):
+ model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
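+ # Hand-crafted device map: embeddings and head on GPU 0, the first blocks on CPU, the last ones offloaded to disk.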
+ device_map = {
+ "transformer.wte": 0,
+ "transformer.wpe": 0,
+ "transformer.h.0": "cpu",
+ "transformer.h.1": "cpu",
+ "transformer.h.2": "cpu",
+ "transformer.h.3": "disk",
+ "transformer.h.4": "disk",
+ "transformer.ln_f": 0,
+ "lm_head": 0,
+ }
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ inputs = torch.tensor([[1, 2, 3]]).to(0)
+
+ model.save_pretrained(tmp_dir)
+ new_model = AutoModelForCausalLM.from_pretrained(tmp_dir).to(0)
+ outputs1 = new_model.to(0)(inputs)
+
+ offload_folder = os.path.join(tmp_dir, "offload")
+ new_model_with_offload = AutoModelForCausalLM.from_pretrained(
+ tmp_dir, device_map=device_map, offload_folder=offload_folder
+ )
+ outputs2 = new_model_with_offload(inputs)
+
+ self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
+
+ # With state dict temp offload
+ offload_folder = os.path.join(tmp_dir, "offload")
+ new_model_with_offload = AutoModelForCausalLM.from_pretrained(
+ tmp_dir,
+ device_map=device_map,
+ offload_folder=offload_folder,
+ offload_state_dict=True,
+ )
+ outputs2 = new_model_with_offload(inputs)
+
+ self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
+
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
- response_mock.headers = []
+ response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
- with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
+ with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check ensures we did call the fake head request
mock_head.assert_called()
@@ -2583,27 +2943,24 @@ def test_cached_files_are_used_when_internet_is_down(self):
class ModelPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls._token = login(username=USER, password=PASS)
+ cls._token = TOKEN
+ set_access_token(TOKEN)
+ HfFolder.save_token(TOKEN)
@classmethod
def tearDownClass(cls):
try:
- delete_repo(token=cls._token, name="test-model")
- except HTTPError:
- pass
-
- try:
- delete_repo(token=cls._token, name="test-model-org", organization="valid_org")
+ delete_repo(token=cls._token, repo_id="test-model")
except HTTPError:
pass
try:
- delete_repo(token=cls._token, name="test-dynamic-model")
+ delete_repo(token=cls._token, repo_id="valid_org/test-model-org")
except HTTPError:
pass
try:
- delete_repo(token=cls._token, name="test-dynamic-model-config")
+ delete_repo(token=cls._token, repo_id="test-dynamic-model")
except HTTPError:
pass
@@ -2612,29 +2969,46 @@ def test_push_to_hub(self):
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
+ model.push_to_hub("test-model", use_auth_token=self._token)
+
+ new_model = BertModel.from_pretrained(f"{USER}/test-model")
+ for p1, p2 in zip(model.parameters(), new_model.parameters()):
+ self.assertTrue(torch.equal(p1, p2))
+
+ # Reset repo
+ delete_repo(token=self._token, repo_id="test-model")
+
+ # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
- model.save_pretrained(os.path.join(tmp_dir, "test-model"), push_to_hub=True, use_auth_token=self._token)
+ model.save_pretrained(tmp_dir, repo_id="test-model", push_to_hub=True, use_auth_token=self._token)
- new_model = BertModel.from_pretrained(f"{USER}/test-model")
- for p1, p2 in zip(model.parameters(), new_model.parameters()):
- self.assertTrue(torch.equal(p1, p2))
+ new_model = BertModel.from_pretrained(f"{USER}/test-model")
+ for p1, p2 in zip(model.parameters(), new_model.parameters()):
+ self.assertTrue(torch.equal(p1, p2))
def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
+ model.push_to_hub("valid_org/test-model-org", use_auth_token=self._token)
+
+ new_model = BertModel.from_pretrained("valid_org/test-model-org")
+ for p1, p2 in zip(model.parameters(), new_model.parameters()):
+ self.assertTrue(torch.equal(p1, p2))
+
+ # Reset repo
+ delete_repo(token=self._token, repo_id="valid_org/test-model-org")
+
+ # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
- os.path.join(tmp_dir, "test-model-org"),
- push_to_hub=True,
- use_auth_token=self._token,
- organization="valid_org",
+ tmp_dir, push_to_hub=True, use_auth_token=self._token, repo_id="valid_org/test-model-org"
)
- new_model = BertModel.from_pretrained("valid_org/test-model-org")
- for p1, p2 in zip(model.parameters(), new_model.parameters()):
- self.assertTrue(torch.equal(p1, p2))
+ new_model = BertModel.from_pretrained("valid_org/test-model-org")
+ for p1, p2 in zip(model.parameters(), new_model.parameters()):
+ self.assertTrue(torch.equal(p1, p2))
def test_push_to_hub_dynamic_model(self):
CustomConfig.register_for_auto_class()
@@ -2643,16 +3017,12 @@ def test_push_to_hub_dynamic_model(self):
config = CustomConfig(hidden_size=32)
model = CustomModel(config)
- with tempfile.TemporaryDirectory() as tmp_dir:
- repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-model", use_auth_token=self._token)
- model.save_pretrained(tmp_dir)
- # checks
- self.assertDictEqual(
- config.auto_map,
- {"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"},
- )
-
- repo.push_to_hub()
+ model.push_to_hub("test-dynamic-model", use_auth_token=self._token)
+ # checks
+ self.assertDictEqual(
+ config.auto_map,
+ {"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"},
+ )
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
# Can't make an isinstance check because the new_model is from the CustomModel class of a dynamic module
diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index b4238facc1ae..e22c7e6705b3 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -14,6 +14,7 @@
import copy
import inspect
+import json
import random
import tempfile
import unittest
@@ -22,12 +23,12 @@
import numpy as np
import transformers
-from huggingface_hub import delete_repo, login
+from huggingface_hub import HfFolder, delete_repo, set_access_token
from requests.exceptions import HTTPError
from transformers import BertConfig, is_flax_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
- PASS,
+ TOKEN,
USER,
CaptureLogger,
is_pt_flax_cross_test,
@@ -36,6 +37,7 @@
torch_device,
)
from transformers.utils import logging
+from transformers.utils.generic import ModelOutput
if is_flax_available():
@@ -44,6 +46,7 @@
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+ from flax.serialization import from_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers import (
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -57,6 +60,7 @@
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
+ from transformers.modeling_flax_utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
@@ -169,8 +173,8 @@ def recursive_check(tuple_object, dict_object):
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
- # (Copied from tests.test_modeling_common.ModelTesterMixin.check_outputs)
- def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
+ # (Copied from tests.test_modeling_common.ModelTesterMixin.check_pt_flax_outputs)
+ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
"""
Args:
model_class: The class of the model that is currently testing. For example, ..., etc.
@@ -180,24 +184,71 @@ def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
- if type(fx_outputs) in [tuple, list]:
- self.assertEqual(type(fx_outputs), type(pt_outputs))
- self.assertEqual(len(fx_outputs), len(pt_outputs))
- if type(names) == tuple:
- for fo, po, name in zip(fx_outputs, pt_outputs, names):
- self.check_outputs(fo, po, model_class, names=name)
- elif type(names) == str:
- for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
- self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
+
+ self.assertEqual(type(name), str)
+ if attributes is not None:
+ self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
+
+ # Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
+ if isinstance(fx_outputs, ModelOutput):
+ self.assertTrue(
+ isinstance(pt_outputs, ModelOutput),
+ f"{name}: `pt_outputs` should an instance of `ModelOutput` when `fx_outputs` is",
+ )
+
+ fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+ pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+ self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch")
+
+ # convert to the case of `tuple`
+ # appending each key to the current (string) `name`
+ attributes = tuple([f"{name}.{k}" for k in fx_keys])
+ self.check_pt_flax_outputs(
+ fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
+ )
+
+ # Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
+ elif type(fx_outputs) in [tuple, list]:
+ self.assertEqual(
+ type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch"
+ )
+ self.assertEqual(
+ len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch"
+ )
+
+ if attributes is not None:
+ # case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
+ self.assertEqual(
+ len(attributes),
+ len(fx_outputs),
+ f"{name}: The tuple `attributes` should have the same length as `fx_outputs`",
+ )
else:
- raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
+ # case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
+ attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
+
+ for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
+ self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
+
elif isinstance(fx_outputs, jnp.ndarray):
- self.assertTrue(isinstance(pt_outputs, torch.Tensor))
+ self.assertTrue(
+ isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should be a tensor when `fx_outputs` is"
+ )
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
fx_outputs = np.array(fx_outputs)
pt_outputs = pt_outputs.detach().to("cpu").numpy()
+ self.assertEqual(
+ fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch"
+ )
+
+ # deal with NumPy's scalars to make replacing nan values by 0 work.
+ if np.isscalar(fx_outputs):
+ fx_outputs = np.array([fx_outputs])
+ pt_outputs = np.array([pt_outputs])
+
fx_nans = np.isnan(fx_outputs)
pt_nans = np.isnan(pt_outputs)
@@ -206,10 +257,14 @@ def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
pt_outputs[pt_nans] = 0
fx_outputs[pt_nans] = 0
- self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5)
+ max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
+ self.assertLessEqual(
+ max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})."
+ )
else:
raise ValueError(
- f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
+ "`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got"
+ f" {type(fx_outputs)} instead."
)
@is_pt_flax_cross_test
@@ -253,7 +308,7 @@ def test_equivalence_pt_to_flax(self):
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
- self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
+ self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
@@ -265,7 +320,7 @@ def test_equivalence_pt_to_flax(self):
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
- self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
+ self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
@@ -308,7 +363,7 @@ def test_equivalence_flax_to_pt(self):
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
- self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
+ self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
@@ -325,7 +380,7 @@ def test_equivalence_flax_to_pt(self):
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
- self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
+ self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -991,23 +1046,105 @@ def _assert_all_params_initialised(model, params):
# Check if all required parmas are loaded
_assert_all_params_initialised(model, params)
+ def test_checkpoint_sharding_from_hub(self):
+ model = FlaxBertModel.from_pretrained("ArthurZ/flax-tiny-random-bert-sharded")
+ # the model above is the same as the model below, just a sharded version.
+ ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
+ for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
+ assert np.allclose(np.array(p1), np.array(p2))
+
+ def test_checkpoint_sharding_local(self):
+ model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
+ for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
+ model.save_pretrained(tmp_dir, max_shard_size=max_size)
+
+ # Get each shard file and its size
+ shard_to_size = {}
+ for shard in os.listdir(tmp_dir):
+ if shard.endswith(".msgpack"):
+ shard_file = os.path.join(tmp_dir, shard)
+ shard_to_size[shard_file] = os.path.getsize(shard_file)
+
+ index_file = os.path.join(tmp_dir, FLAX_WEIGHTS_INDEX_NAME)
+ # Check there is an index but no regular weight file
+ self.assertTrue(os.path.isfile(index_file))
+ self.assertFalse(os.path.isfile(os.path.join(tmp_dir, FLAX_WEIGHTS_NAME)))
+
+ # Check a file is bigger than max_size only when it has a single weight
+ for shard_file, size in shard_to_size.items():
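+ # "kiB" is a binary unit (multiples of 2**10) while "kB" is decimal (multiples of 10**3).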
+ if max_size.endswith("kiB"):
+ max_size_int = int(max_size[:-3]) * 2**10
+ else:
+ max_size_int = int(max_size[:-2]) * 10**3
+ # Note: serialization adds some overhead, so the size of the file can end up being slightly bigger than
+ # the size asked for (the limit is computed from the parameter sizes)
+ if size >= max_size_int + 50000:
+ with open(shard_file, "rb") as state_f:
+ state_file = from_bytes(FlaxBertModel, state_f.read())
+ self.assertEqual(len(state_file), 1)
+
+ # Check the index and the shard files found match
+ with open(index_file, "r", encoding="utf-8") as f:
+ index = json.loads(f.read())
+
+ all_shards = set(index["weight_map"].values())
+ shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".msgpack"))
+ self.assertSetEqual(all_shards, shards_found)
+
+ # Finally, check the model can be reloaded
+ new_model = FlaxBertModel.from_pretrained(tmp_dir)
+ for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
+ self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
+
+ def test_gradient_checkpointing(self):
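+ # Outputs with gradient checkpointing (remat) enabled must match the regular forward pass exactly.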
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ # prepare inputs
+ prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ model = model_class(config)
+ remat_model = model_class(config)
+
+ try:
+ remat_model.enable_gradient_checkpointing()
+ except NotImplementedError:
+ continue
+
+ outputs = model(**prepared_inputs_dict)
+ remat_outputs = remat_model(**prepared_inputs_dict)
+
+ # ensure that the dicts of outputs contain the same keys
+ self.assertEqual(outputs.keys(), remat_outputs.keys())
+
+ outputs = outputs.to_tuple()
+ remat_outputs = remat_outputs.to_tuple()
+
+ # ensure that the outputs remain precisely equal
+ for output, remat_output in zip(outputs, remat_outputs):
+ self.assertTrue((output == remat_output).all())
+
@require_flax
@is_staging_test
class FlaxModelPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls._token = login(username=USER, password=PASS)
+ cls._token = TOKEN
+ set_access_token(TOKEN)
+ HfFolder.save_token(TOKEN)
@classmethod
def tearDownClass(cls):
try:
- delete_repo(token=cls._token, name="test-model-flax")
+ delete_repo(token=cls._token, repo_id="test-model-flax")
except HTTPError:
pass
try:
- delete_repo(token=cls._token, name="test-model-flax-org", organization="valid_org")
+ delete_repo(token=cls._token, repo_id="valid_org/test-model-flax-org")
except HTTPError:
pass
@@ -1016,38 +1153,63 @@ def test_push_to_hub(self):
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
+ model.push_to_hub("test-model-flax", use_auth_token=self._token)
+
+ new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
+
+ base_params = flatten_dict(unfreeze(model.params))
+ new_params = flatten_dict(unfreeze(new_model.params))
+
+ for key in base_params.keys():
+ max_diff = (base_params[key] - new_params[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+ # Reset repo
+ delete_repo(token=self._token, repo_id="test-model-flax")
+
+ # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
- model.save_pretrained(
- os.path.join(tmp_dir, "test-model-flax"), push_to_hub=True, use_auth_token=self._token
- )
+ model.save_pretrained(tmp_dir, repo_id="test-model-flax", push_to_hub=True, use_auth_token=self._token)
- new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
+ new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
- base_params = flatten_dict(unfreeze(model.params))
- new_params = flatten_dict(unfreeze(new_model.params))
+ base_params = flatten_dict(unfreeze(model.params))
+ new_params = flatten_dict(unfreeze(new_model.params))
- for key in base_params.keys():
- max_diff = (base_params[key] - new_params[key]).sum().item()
- self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+ for key in base_params.keys():
+ max_diff = (base_params[key] - new_params[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
+ model.push_to_hub("valid_org/test-model-flax-org", use_auth_token=self._token)
+
+ new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
+
+ base_params = flatten_dict(unfreeze(model.params))
+ new_params = flatten_dict(unfreeze(new_model.params))
+
+ for key in base_params.keys():
+ max_diff = (base_params[key] - new_params[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+ # Reset repo
+ delete_repo(token=self._token, repo_id="valid_org/test-model-flax-org")
+
+ # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
- os.path.join(tmp_dir, "test-model-flax-org"),
- push_to_hub=True,
- use_auth_token=self._token,
- organization="valid_org",
+ tmp_dir, repo_id="valid_org/test-model-flax-org", push_to_hub=True, use_auth_token=self._token
)
- new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
+ new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
- base_params = flatten_dict(unfreeze(model.params))
- new_params = flatten_dict(unfreeze(new_model.params))
+ base_params = flatten_dict(unfreeze(model.params))
+ new_params = flatten_dict(unfreeze(new_model.params))
- for key in base_params.keys():
- max_diff = (base_params[key] - new_params[key]).sum().item()
- self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+ for key in base_params.keys():
+ max_diff = (base_params[key] - new_params[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 0d38713e08d3..abf26af2b651 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -23,24 +23,28 @@
import unittest
import unittest.mock as mock
from importlib import import_module
+from math import isnan
from typing import List, Tuple
-from huggingface_hub import delete_repo, login
+from datasets import Dataset
+
+from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
from requests.exceptions import HTTPError
from transformers import is_tf_available, is_torch_available
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import get_values
-from transformers.testing_utils import tooslow # noqa: F401
-from transformers.testing_utils import (
- PASS,
+from transformers.testing_utils import ( # noqa: F401
+ TOKEN,
USER,
CaptureLogger,
+ CaptureStdout,
_tf_gpu_memory_limit,
is_pt_tf_cross_test,
is_staging_test,
require_tf,
require_tf2onnx,
slow,
+ tooslow,
torch_device,
)
from transformers.utils import logging
@@ -51,23 +55,27 @@
if is_tf_available():
+ import h5py
import numpy as np
import tensorflow as tf
from transformers import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
TF_MODEL_FOR_PRETRAINING_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+ TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
BertConfig,
TFAutoModel,
+ TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification,
TFBertModel,
TFSharedEmbeddings,
@@ -83,7 +91,12 @@
TFSampleDecoderOnlyOutput,
TFSampleEncoderDecoderOutput,
)
- from transformers.modeling_tf_utils import unpack_inputs
+ from transformers.modeling_tf_utils import (
+ TF2_WEIGHTS_INDEX_NAME,
+ TF2_WEIGHTS_NAME,
+ tf_shard_checkpoint,
+ unpack_inputs,
+ )
from transformers.tf_utils import stable_softmax
if _tf_gpu_memory_limit is not None:
@@ -159,6 +172,15 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> d
inputs_dict["labels"] = tf.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
)
+ elif model_class in get_values(TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING):
+ num_patches = self.model_tester.image_size // self.model_tester.patch_size
+ inputs_dict["bool_masked_pos"] = tf.zeros(
+ (self.model_tester.batch_size, num_patches**2), dtype=tf.int32
+ )
+ elif model_class in get_values(TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING):
+ batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
+ inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, height, width), dtype=tf.int32)
+
return inputs_dict
def test_initialization(self):
@@ -196,6 +218,47 @@ def test_save_load_config(self):
self.assert_outputs_same(after_outputs, outputs)
+ @slow
+ def test_saved_model_creation(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = False
+ config.output_attentions = False
+
+ if hasattr(config, "use_cache"):
+ config.use_cache = False
+
+ model_class = self.all_model_classes[0]
+
+ class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ model = model_class(config)
+
+ model(class_inputs_dict)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname, saved_model=True)
+ saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
+ self.assertTrue(os.path.exists(saved_model_dir))
+
+ def test_prepare_serving_output(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = self.has_attentions
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ inputs = self._prepare_for_class(inputs_dict, model_class)
+ outputs = model(inputs)
+ serving_outputs = model.serving_output(outputs)
+
+ for k, v in serving_outputs.items():
+ # Check that we have one of three possible outputs: None, tuple of tensors or a tensor
+ if isinstance(v, tuple):
+ self.assertTrue(all(isinstance(elem, tf.Tensor) for elem in v))
+ elif v is not None:
+ self.assertIsInstance(v, tf.Tensor)
+ else:
+ self.assertIsNone(v)
+
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -212,10 +275,10 @@ def test_forward_signature(self):
"decoder_input_ids",
"decoder_attention_mask",
]
+ expected_arg_names.extend(["decoder_position_ids"] if "decoder_position_ids" in arg_names else [])
expected_arg_names.extend(
["head_mask", "decoder_head_mask"] if "head_mask" and "decoder_head_mask" in arg_names else []
)
- # Necessary to handle BART with newly added cross_attn_head_mask
expected_arg_names.extend(
["cross_attn_head_mask", "encoder_outputs"]
if "cross_attn_head_mask" in arg_names
@@ -419,7 +482,7 @@ def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_cla
return new_tf_outputs, new_pt_outputs
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
- """Check the outputs from PyTorch and TensorFlow models are closed enough. Checks are done in a recursive way.
+ """Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.
Args:
model_class: The class of the model that is currently testing. For example, `TFBertModel`,
@@ -445,8 +508,8 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, nam
# TODO: remove this method and this line after issues are fixed
tf_outputs, pt_outputs = self._postprocessing_to_ignore_test_cases(tf_outputs, pt_outputs, model_class)
- tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
- pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+ tf_keys = [k for k, v in tf_outputs.items() if v is not None]
+ pt_keys = [k for k, v in pt_outputs.items() if v is not None]
self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")
@@ -505,7 +568,8 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, nam
self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
else:
raise ValueError(
- f"`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
+ "`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got"
+ f" {type(tf_outputs)} instead."
)
def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
@@ -956,7 +1020,10 @@ def recursive_check(tuple_object, dict_object):
else:
self.assertTrue(
all(tf.equal(tuple_object, dict_object)),
- msg=f"Tuple and dict output are not equal. Difference: {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}",
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}"
+ ),
)
recursive_check(tuple_output, dict_output)
@@ -972,9 +1039,10 @@ def recursive_check(tuple_object, dict_object):
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
- tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
- dict_inputs = self._prepare_for_class(inputs_dict, model_class)
- check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
+ if self.has_attentions:
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+ check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
# Not all models accept "labels" in the forward pass (yet :) )
if "labels" in inspect.signature(model.call).parameters.keys():
@@ -986,15 +1054,16 @@ def recursive_check(tuple_object, dict_object):
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
- tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
- dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
- check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
+ if self.has_attentions:
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
- tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
- dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
- check_equivalence(
- model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
- )
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ check_equivalence(
+ model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
+ )
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -1270,12 +1339,7 @@ def test_loss_computation(self):
added_label = prepared_for_class[
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
]
- loss_size = tf.size(added_label)
-
- if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
- # if loss is causal lm loss, labels are shift, so that one label per batch
- # is cut
- loss_size = loss_size - self.model_tester.batch_size
+ expected_loss_size = added_label.shape.as_list()[:1]
# Test that model correctly compute the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -1284,12 +1348,26 @@ def test_loss_computation(self):
model_input = prepared_for_class.pop(input_name)
loss = model(model_input, **prepared_for_class)[0]
- self.assertEqual(loss.shape, [loss_size])
+ self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
+
+ # Test that the model correctly computes the loss when we mask some positions
+ prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
+ possible_input_names = {"input_ids", "pixel_values", "input_features"}
+ input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
+ model_input = prepared_for_class.pop(input_name)
+ if "labels" in prepared_for_class:
+ labels = prepared_for_class["labels"].numpy()
+ if len(labels.shape) > 1 and labels.shape[1] != 1:
+ labels[0] = -100
+ prepared_for_class["labels"] = tf.convert_to_tensor(labels)
+ loss = model(model_input, **prepared_for_class)[0]
+ self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
+ self.assertTrue(not np.any(np.isnan(loss.numpy())))
# Test that model correctly compute the loss with a dict
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
loss = model(prepared_for_class)[0]
- self.assertEqual(loss.shape, [loss_size])
+ self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
# Test that model correctly compute the loss with a tuple
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -1320,7 +1398,10 @@ def test_loss_computation(self):
# Send to model
loss = model(tuple_input[:-1])[0]
- self.assertEqual(loss.shape, [loss_size])
+ self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
+
+ def check_keras_fit_results(self, val_loss1, val_loss2, atol=1e-2, rtol=1e-3):
+ self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol))
def test_keras_fit(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -1351,7 +1432,29 @@ def test_keras_fit(self):
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
self.assertGreater(len(inputs_minus_labels), 0)
- model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
+ accuracy_classes = [
+ "ForPreTraining",
+ "ForCausalLM",
+ "ForMaskedLM",
+ "ForQuestionAnswering",
+ "ForMultipleChoice",
+ "ForSequenceClassification",
+ "ForTokenClassification",
+ "ForNextSentencePrediction",
+ "LMHeadModel",
+ ]
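+ # Attach a sparse-categorical accuracy metric only to heads that produce classification logits; other heads get no metric.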
+ for accuracy_class in accuracy_classes:
+ if model.__class__.__name__.endswith(accuracy_class):
+ metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
+ break
+ else:
+ metrics = []
+
+ model(model.dummy_inputs) # Build the model so we can get some constant weights
+ model_weights = model.get_weights()
+
+ # Run eagerly to save some expensive compilation times
+ model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
# Make sure the model fits without crashing regardless of where we pass the labels
history1 = model.fit(
prepared_for_class,
@@ -1361,6 +1464,13 @@ def test_keras_fit(self):
shuffle=False,
)
val_loss1 = history1.history["val_loss"][0]
+ self.assertTrue(not isnan(val_loss1))
+ accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
+
+ # We reinitialize the model here even though our learning rate was zero
+ # because BatchNorm updates weights by means other than gradient descent.
+ model.set_weights(model_weights)
+
history2 = model.fit(
inputs_minus_labels,
labels,
@@ -1370,7 +1480,58 @@ def test_keras_fit(self):
shuffle=False,
)
val_loss2 = history2.history["val_loss"][0]
- self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
+ self.assertTrue(not isnan(val_loss2))
+ accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
+ self.check_keras_fit_results(val_loss1, val_loss2)
+ self.assertEqual(history1.history.keys(), history2.history.keys())
+ for key in history1.history.keys():
+ if not key.startswith("val_"):
+ self.assertTrue("val_" + key in history1.history.keys(), "Outputs differ in train/test step!")
+ if metrics:
+ self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
+
+ # Make sure fit works with tf.data.Dataset and results are consistent
+ dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
+ # Pass in all samples as a batch to match other `fit` calls
+ dataset = dataset.batch(len(dataset))
+
+ # Reinitialize to fix batchnorm again
+ model.set_weights(model_weights)
+
+ history3 = model.fit(
+ dataset,
+ validation_data=dataset,
+ steps_per_epoch=1,
+ validation_steps=1,
+ shuffle=False,
+ )
+ val_loss3 = history3.history["val_loss"][0]
+ self.assertTrue(not isnan(val_loss3))
+ accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
+ self.check_keras_fit_results(val_loss1, val_loss3)
+ self.assertEqual(history1.history.keys(), history3.history.keys())
+ if metrics:
+ self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
+
+ def test_int64_inputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ prepared_for_class = self._prepare_for_class(
+ inputs_dict.copy(),
+ model_class,
+ return_labels=True if "labels" in inspect.signature(model_class.call).parameters.keys() else False,
+ )
+ if not any(
+ [tensor.dtype.is_integer for tensor in prepared_for_class.values() if isinstance(tensor, tf.Tensor)]
+ ):
+ return # No integer inputs means no need for this test
+
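+ # Upcast every integer input tensor to int64 and make sure the forward pass still works.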
+ prepared_for_class = {
+ key: tf.cast(tensor, tf.int64) if isinstance(tensor, tf.Tensor) and tensor.dtype.is_integer else tensor
+ for key, tensor in prepared_for_class.items()
+ }
+ model = model_class(config)
+ model(**prepared_for_class) # No assertion, we're just checking this doesn't throw an error
def test_generate_with_headmasking(self):
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
@@ -1459,6 +1620,131 @@ def test_model_main_input_name(self):
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(model_class.main_input_name, observed_main_input_name)
+ def test_dataset_conversion(self):
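+ # Convert the prepared inputs to a datasets.Dataset and check prepare_tf_dataset keeps the model inputs and drops extra columns.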
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=False)
+ tf_inputs_dict = {
+ key: val
+ for key, val in tf_inputs_dict.items()
+ if "head_mask" not in key and isinstance(val, tf.Tensor)
+ }
+ tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0] # Use a random other tensor
+ input_dataset = Dataset.from_dict(tf_inputs_dict)
+ tf_dataset = model.prepare_tf_dataset(
+ input_dataset, batch_size=len(input_dataset), drop_remainder=False, shuffle=False
+ )
+ test_batch = next(iter(tf_dataset))
+ if isinstance(test_batch, tf.Tensor):
+ self.assertEqual(len(test_batch), len(input_dataset)) # Assert we didn't lose any data
+ else:
+ # Assert we discarded the unwanted extra column but kept everything else
+ self.assertEqual(len(test_batch), len(input_dataset.features) - 1)
+ self.assertNotIn("extra_unwanted_column", test_batch)
+ for tensor in test_batch.values():
+ self.assertTrue(isinstance(tensor, tf.Tensor))
+ self.assertEqual(len(tensor), len(input_dataset)) # Assert we didn't lose any data
+ model(test_batch, training=False)
+
+ if "labels" in inspect.signature(model_class.call).parameters.keys():
+ tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ if "labels" not in tf_inputs_dict:
+ return # This model isn't giving us labels after all, don't try training with it
+ tf_inputs_dict = {key: val for key, val in tf_inputs_dict.items() if "head_mask" not in key}
+ tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0] # Use a random other tensor
+ input_dataset = Dataset.from_dict(tf_inputs_dict)
+ tf_dataset = model.prepare_tf_dataset(
+ input_dataset, batch_size=len(input_dataset), drop_remainder=False, shuffle=False
+ )
+ test_batch, test_batch_labels = next(iter(tf_dataset))
+ self.assertGreater(len(test_batch_labels), 0) # Assert the labels are present
+ feature_columns = 1 if isinstance(test_batch, tf.Tensor) else len(test_batch)
+ label_columns = 1 if isinstance(test_batch_labels, tf.Tensor) else len(test_batch_labels)
+ # Assert we discarded the unwanted extra column but kept everything else
+ self.assertEqual(feature_columns + label_columns, len(input_dataset.features) - 1)
+ if isinstance(test_batch, dict):
+ self.assertNotIn("extra_unwanted_column", test_batch)
+ if isinstance(test_batch_labels, dict):
+ self.assertNotIn("extra_unwanted_column", test_batch_labels)
+ model.compile(optimizer="sgd", run_eagerly=True)
+ model.train_on_batch(test_batch, test_batch_labels)
+
+ def _test_xla_generate(self, num_beams, num_return_sequences, max_length):
+ def _generate_and_check_results(model, config, inputs_dict):
+ if "input_ids" in inputs_dict:
+ inputs = inputs_dict["input_ids"]
+ # make sure there are no pad tokens in prompt, which may trigger unwanted behavior
+ if config.pad_token_id is not None:
+ if config.pad_token_id == 0:
+ new_pad_token = config.pad_token_id + 1
+ else:
+ new_pad_token = config.pad_token_id - 1
+ else:
+ new_pad_token = None
+ inputs = tf.where(inputs != config.pad_token_id, inputs, new_pad_token)
+ elif "input_features" in inputs_dict:
+ inputs = inputs_dict["input_features"]
+ else:
+ raise ValueError("No valid generate input found in inputs_dict")
+
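+ # Generate once eagerly and once through an XLA-compiled tf.function; the token ids must match exactly.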
+ generated = model.generate(inputs).numpy()
+ generate_xla = tf.function(model.generate, jit_compile=True)
+ generated_xla = generate_xla(inputs).numpy()
+ self.assertListEqual(generated.tolist(), generated_xla.tolist())
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.eos_token_id = None # Generate until max length
+ config.max_length = max_length
+ config.do_sample = False
+ config.num_beams = num_beams
+ config.num_return_sequences = num_return_sequences
+
+ # fix config for models with additional sequence-length limiting settings
+ for var_name in ["max_position_embeddings", "max_target_positions"]:
+ if hasattr(config, var_name):
+ try:
+ setattr(config, var_name, max_length)
+ except NotImplementedError:
+ # xlnet will raise an exception when trying to set
+ # max_position_embeddings.
+ pass
+
+ model = model_class(config)
+
+ if model.supports_xla_generation:
+ _generate_and_check_results(model, config, inputs_dict)
+ else:
+ with self.assertRaises(ValueError):
+ _generate_and_check_results(model, config, inputs_dict)
+
+ def test_xla_generate_fast(self):
+ """
+ Basic quick test for generate-compatible classes that confirms that XLA-generated tokens are the same as their
+ non XLA counterparts.
+
+ Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
+ """
+ num_beams = 1
+ num_return_sequences = 1
+ max_length = 10
+ self._test_xla_generate(num_beams, num_return_sequences, max_length)
+
+ @slow
+ def test_xla_generate_slow(self):
+ """
+ Slow and challenging version of `test_xla_generate_fast` -- this test asks for several long sequences using
+ beam search, with and without XLA. The two outputs should match, and a failure in this test indicates that the
+ model may need further analysis if it is to be used for XLA generation.
+
+ Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
+ """
+ num_beams = 8
+ num_return_sequences = 2
+ max_length = 128
+ self._test_xla_generate(num_beams, num_return_sequences, max_length)
+
def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens
special_tokens = []
@@ -1636,14 +1922,14 @@ def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
- response_mock.headers = []
+ response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
- with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
+ with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()
@@ -1654,6 +1940,7 @@ class DummyModel:
def __init__(self):
config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
self.config = PretrainedConfig(**config_kwargs)
+ self.main_input_name = "input_ids"
@unpack_inputs
def call(
@@ -1661,9 +1948,14 @@ def call(
):
return input_ids, past, output_attentions, output_hidden_states, return_dict
+ @unpack_inputs
+ def foo(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
+ return pixel_values, output_attentions, output_hidden_states, return_dict
+
dummy_model = DummyModel()
input_ids = tf.constant([0, 1, 2, 3])
past = tf.constant([4, 5, 6, 7])
+ pixel_values = tf.constant([8, 9, 10, 11])
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
output = dummy_model.call(input_ids=input_ids, past=past)
@@ -1710,6 +2002,14 @@ def call(
self.assertFalse(output[3])
self.assertFalse(output[4])
+ # test case 7: the decorator is independent from `main_input_name` -- it treats the first argument of the
+ # decorated function as its main input.
+ output = dummy_model.foo(pixel_values=pixel_values)
+ tf.debugging.assert_equal(output[0], pixel_values)
+ self.assertFalse(output[1])
+ self.assertFalse(output[2])
+ self.assertFalse(output[3])
+
# Tests whether the stable softmax is stable on CPU, with and without XLA
def test_xla_stable_softmax(self):
large_penalty = -1e9
@@ -1745,23 +2045,198 @@ def masked_softmax(x, boolean_mask):
out = masked_softmax(x, boolean_mask)
assert tf.experimental.numpy.allclose(xla_out, out)
+ def test_checkpoint_sharding_from_hub(self):
+ model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
+ # the model above is the same as the model below, just a sharded version.
+ ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+ for p1, p2 in zip(model.weights, ref_model.weights):
+ assert np.allclose(p1.numpy(), p2.numpy())
+
+ @is_pt_tf_cross_test
+ def test_checkpoint_sharding_local_from_pt(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ _ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-bert-sharded")
+ model = TFBertModel.from_pretrained(tmp_dir, from_pt=True)
+ # the model above is the same as the model below, just a sharded pytorch version.
+ ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+ for p1, p2 in zip(model.weights, ref_model.weights):
+ assert np.allclose(p1.numpy(), p2.numpy())
+
+ def test_shard_checkpoint(self):
+ # This is the model we will use, total size 340,000 bytes.
+ model = tf.keras.Sequential(
+ [
+ tf.keras.layers.Dense(200, use_bias=False), # size 80,000
+ tf.keras.layers.Dense(200, use_bias=False), # size 160,000
+ tf.keras.layers.Dense(100, use_bias=False), # size 80,000
+ tf.keras.layers.Dense(50, use_bias=False), # size 20,000
+ ]
+ )
+ inputs = tf.zeros((1, 100), dtype=tf.float32)
+ model(inputs)
+ weights = model.weights
+ weights_dict = {w.name: w for w in weights}
+ with self.subTest("No shard when max size is bigger than model size"):
+ shards, index = tf_shard_checkpoint(weights)
+ self.assertIsNone(index)
+ self.assertDictEqual(shards, {TF2_WEIGHTS_NAME: weights})
+
+ with self.subTest("Test sharding, no weights bigger than max size"):
+ shards, index = tf_shard_checkpoint(weights, max_shard_size="300kB")
+ # Split is first two layers then last two.
+ self.assertDictEqual(
+ index,
+ {
+ "metadata": {"total_size": 340000},
+ "weight_map": {
+ "dense/kernel:0": "tf_model-00001-of-00002.h5",
+ "dense_1/kernel:0": "tf_model-00001-of-00002.h5",
+ "dense_2/kernel:0": "tf_model-00002-of-00002.h5",
+ "dense_3/kernel:0": "tf_model-00002-of-00002.h5",
+ },
+ },
+ )
+
+ shard1 = [weights_dict["dense/kernel:0"], weights_dict["dense_1/kernel:0"]]
+ shard2 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
+ self.assertDictEqual(shards, {"tf_model-00001-of-00002.h5": shard1, "tf_model-00002-of-00002.h5": shard2})
+
+ with self.subTest("Test sharding with weights bigger than max size"):
+ shards, index = tf_shard_checkpoint(weights, max_shard_size="100kB")
+ # Split is first layer, then second layer, then the last two.
+ self.assertDictEqual(
+ index,
+ {
+ "metadata": {"total_size": 340000},
+ "weight_map": {
+ "dense/kernel:0": "tf_model-00001-of-00003.h5",
+ "dense_1/kernel:0": "tf_model-00002-of-00003.h5",
+ "dense_2/kernel:0": "tf_model-00003-of-00003.h5",
+ "dense_3/kernel:0": "tf_model-00003-of-00003.h5",
+ },
+ },
+ )
+
+ shard1 = [weights_dict["dense/kernel:0"]]
+ shard2 = [weights_dict["dense_1/kernel:0"]]
+ shard3 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
+ self.assertDictEqual(
+ shards,
+ {
+ "tf_model-00001-of-00003.h5": shard1,
+ "tf_model-00002-of-00003.h5": shard2,
+ "tf_model-00003-of-00003.h5": shard3,
+ },
+ )
+
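For reference, the byte counts quoted in `test_shard_checkpoint` follow directly from the float32 kernel shapes (4 bytes per parameter, input feature dimension 100); a quick sanity check in plain Python, independent of the test harness:

# Kernel shapes: 100x200, 200x200, 200x100, 100x50 -- float32, i.e. 4 bytes per parameter.
layer_bytes = [100 * 200 * 4, 200 * 200 * 4, 200 * 100 * 4, 100 * 50 * 4]
assert layer_bytes == [80_000, 160_000, 80_000, 20_000]
assert sum(layer_bytes) == 340_000  # matches the "total_size" recorded in the sharding index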
+ def test_checkpoint_sharding_local(self):
+ model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
+ for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
+ model.save_pretrained(tmp_dir, max_shard_size=max_size)
+
+ # Get each shard file and its size
+ shard_to_size = {}
+ for shard in os.listdir(tmp_dir):
+ if shard.endswith(".h5"):
+ shard_file = os.path.join(tmp_dir, shard)
+ shard_to_size[shard_file] = os.path.getsize(shard_file)
+
+ index_file = os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)
+ # Check there is an index but no regular weight file
+ self.assertTrue(os.path.isfile(index_file))
+ self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
+
+ # Check a file is bigger than max_size only when it has a single weight
+ for shard_file, size in shard_to_size.items():
+ if max_size.endswith("kiB"):
+ max_size_int = int(max_size[:-3]) * 2**10
+ else:
+ max_size_int = int(max_size[:-2]) * 10**3
+ # Note: the on-disk format adds some overhead, so the size of the file can end up being slightly
+ # bigger than the size asked for (since we count parameters, not bytes on disk)
+ if size >= max_size_int + 50000:
+ with h5py.File(shard_file, "r") as state_file:
+ self.assertEqual(len(state_file), 1)
+
+ # Check the index and the shard files found match
+ with open(index_file, "r", encoding="utf-8") as f:
+ index = json.loads(f.read())
+
+ all_shards = set(index["weight_map"].values())
+ shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".h5"))
+ self.assertSetEqual(all_shards, shards_found)
+
+ # Finally, check the model can be reloaded
+ new_model = TFBertModel.from_pretrained(tmp_dir)
+
+ model(model.dummy_inputs)
+ new_model(model.dummy_inputs)
+
+ for p1, p2 in zip(model.weights, new_model.weights):
+ self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+
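A minimal usage sketch of the sharding API exercised by these tests; the local directory name is a placeholder, and `from_pretrained` reassembles the shards transparently:

from transformers import TFBertModel

model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Writes tf_model-0000X-of-0000Y.h5 shard files plus a JSON index instead of a single tf_model.h5.
model.save_pretrained("./tiny-bert-sharded", max_shard_size="200kB")
reloaded = TFBertModel.from_pretrained("./tiny-bert-sharded")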
+ def test_generate_tf_function_export(self):
+ test_model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
+ max_length = 8
+
+ class DummyModel(tf.Module):
+ def __init__(self, model):
+ super(DummyModel, self).__init__()
+ self.model = model
+
+ @tf.function(
+ input_signature=(
+ tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
+ tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
+ ),
+ jit_compile=True,
+ )
+ def serving(self, input_ids, attention_mask):
+ outputs = self.model.generate(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ max_new_tokens=max_length,
+ return_dict_in_generate=True,
+ )
+ return {"sequences": outputs["sequences"]}
+
+ dummy_input_ids = [[2, 3, 4, 1, 0, 0, 0, 0], [102, 103, 104, 105, 1, 0, 0, 0]]
+ dummy_attention_masks = [[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0]]
+ dummy_model = DummyModel(model=test_model)
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
+ serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
+ for batch_size in range(1, len(dummy_input_ids) + 1):
+ inputs = {
+ "input_ids": tf.constant(dummy_input_ids[:batch_size]),
+ "attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
+ }
+ tf_func_outputs = serving_func(**inputs)["sequences"]
+ tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
+ tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
+
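Once exported this way, the SavedModel can be reloaded and queried with plain TensorFlow, without the transformers library; a sketch assuming a placeholder export path and the same (batch, 8) input signature:

import tensorflow as tf

serving_fn = tf.saved_model.load("./exported-generate").signatures["serving_default"]
outputs = serving_fn(
    input_ids=tf.constant([[2, 3, 4, 1, 0, 0, 0, 0]], dtype=tf.int32),
    attention_mask=tf.constant([[1, 1, 1, 1, 0, 0, 0, 0]], dtype=tf.int32),
)
generated = outputs["sequences"]  # shape (batch_size, generated_length)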
@require_tf
@is_staging_test
class TFModelPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls._token = login(username=USER, password=PASS)
+ cls._token = TOKEN
+ set_access_token(TOKEN)
+ HfFolder.save_token(TOKEN)
@classmethod
def tearDownClass(cls):
try:
- delete_repo(token=cls._token, name="test-model-tf")
+ delete_repo(token=cls._token, repo_id="test-model-tf")
except HTTPError:
pass
try:
- delete_repo(token=cls._token, name="test-model-tf-org", organization="valid_org")
+ delete_repo(token=cls._token, repo_id="valid_org/test-model-tf-org")
except HTTPError:
pass
@@ -1772,41 +2247,65 @@ def test_push_to_hub(self):
model = TFBertModel(config)
# Make sure model is properly initialized
_ = model(model.dummy_inputs)
- with tempfile.TemporaryDirectory() as tmp_dir:
- model.save_pretrained(os.path.join(tmp_dir, "test-model-tf"), push_to_hub=True, use_auth_token=self._token)
- new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
- models_equal = True
- for p1, p2 in zip(model.weights, new_model.weights):
- if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
- models_equal = False
- self.assertTrue(models_equal)
-
- def test_push_to_hub_with_model_card(self):
- config = BertConfig(
- vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
- )
- model = TFBertModel(config)
+ logging.set_verbosity_info()
+ logger = logging.get_logger("transformers.utils.hub")
+ with CaptureLogger(logger) as cl:
+ model.push_to_hub("test-model-tf", use_auth_token=self._token)
+ logging.set_verbosity_warning()
+ # Check the model card was created and uploaded.
+ self.assertIn("Uploading README.md to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
+
+ new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
+ models_equal = True
+ for p1, p2 in zip(model.weights, new_model.weights):
+ if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+ models_equal = False
+ self.assertTrue(models_equal)
+
+ # Reset repo
+ delete_repo(token=self._token, repo_id="test-model-tf")
+
+ # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
- model.push_to_hub(os.path.join(tmp_dir, "test-model-tf"))
- self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "test-model-card-tf", "README.md")))
+ model.save_pretrained(tmp_dir, repo_id="test-model-tf", push_to_hub=True, use_auth_token=self._token)
+
+ new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
+ models_equal = True
+ for p1, p2 in zip(model.weights, new_model.weights):
+ if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+ models_equal = False
+ self.assertTrue(models_equal)
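Both upload paths covered above reduce to the following user-facing calls; the repo name, local directory and token are placeholders:

from transformers import BertConfig, TFBertModel

config = BertConfig(vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37)
model = TFBertModel(config)
_ = model(model.dummy_inputs)  # build the weights before uploading
model.push_to_hub("test-model-tf", use_auth_token="<hf_token>")  # direct push
model.save_pretrained("./local-dir", repo_id="test-model-tf", push_to_hub=True, use_auth_token="<hf_token>")  # push while saving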
def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
+ # Make sure model is properly initialized
+ _ = model(model.dummy_inputs)
+
+ model.push_to_hub("valid_org/test-model-tf-org", use_auth_token=self._token)
+
+ new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
+ models_equal = True
+ for p1, p2 in zip(model.weights, new_model.weights):
+ if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+ models_equal = False
+ self.assertTrue(models_equal)
+
+ # Reset repo
+ delete_repo(token=self._token, repo_id="valid_org/test-model-tf-org")
+
+ # Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
- os.path.join(tmp_dir, "test-model-tf-org"),
- push_to_hub=True,
- use_auth_token=self._token,
- organization="valid_org",
+ tmp_dir, push_to_hub=True, use_auth_token=self._token, repo_id="valid_org/test-model-tf-org"
)
- new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
- models_equal = True
- for p1, p2 in zip(model.weights, new_model.weights):
- if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
- models_equal = False
- self.assertTrue(models_equal)
+ new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
+ models_equal = True
+ for p1, p2 in zip(model.weights, new_model.weights):
+ if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+ models_equal = False
+ self.assertTrue(models_equal)
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index fe16e5e1cd52..5941a5711899 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -30,7 +30,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
-from huggingface_hub import Repository, delete_repo, login
+from huggingface_hub import HfFolder, delete_repo, set_access_token
+from parameterized import parameterized
from requests.exceptions import HTTPError
from transformers import (
AlbertTokenizer,
@@ -49,8 +50,9 @@
is_torch_available,
)
from transformers.testing_utils import (
- PASS,
+ TOKEN,
USER,
+ check_json_file_has_correct_format,
get_tests_dir,
is_pt_tf_cross_test,
is_staging_test,
@@ -577,6 +579,25 @@ def test_tokenizers_common_ids_setters(self):
self.assertListEqual(getattr(tokenizer, "additional_special_tokens"), [token_to_test_setters])
self.assertListEqual(getattr(tokenizer, "additional_special_tokens_ids"), [token_id_to_test_setters])
+ @parameterized.expand([(True,), (False,)])
+ def test_tokenizers_special_tokens_properties_unset(self, verbose):
+ tokenizers = self.get_tokenizers(verbose=verbose)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ attributes_list = [
+ "bos_token",
+ "eos_token",
+ "unk_token",
+ "sep_token",
+ "pad_token",
+ "cls_token",
+ "mask_token",
+ "additional_special_tokens",
+ ]
+ for attr in attributes_list:
+ setattr(tokenizer, attr, None)
+ self.assertIsNone(getattr(tokenizer, attr))
+
def test_save_and_load_tokenizer(self):
# safety check on max_len default value so we are sure the test works
tokenizers = self.get_tokenizers()
@@ -969,7 +990,9 @@ def test_maximum_encoding_length_single_input(self):
sequence = tokenizer.encode(seq_0, add_special_tokens=False)
total_length = len(sequence)
- self.assertGreater(total_length, 4, "Issue with the testing sequence, please update it it's too short")
+ self.assertGreater(
+ total_length, 4, "Issue with the testing sequence, please update it, it's too short"
+ )
# Test with max model input length
model_max_length = tokenizer.model_max_length
@@ -979,7 +1002,9 @@ def test_maximum_encoding_length_single_input(self):
sequence1 = tokenizer(seq_1, add_special_tokens=False)
total_length1 = len(sequence1["input_ids"])
self.assertGreater(
- total_length1, model_max_length, "Issue with the testing sequence, please update it it's too short"
+ total_length1,
+ model_max_length,
+ "Issue with the testing sequence, please update it, it's too short",
)
# Simple
@@ -1005,7 +1030,8 @@ def test_maximum_encoding_length_single_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
@@ -1016,7 +1042,8 @@ def test_maximum_encoding_length_single_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
@@ -1131,7 +1158,8 @@ def test_maximum_encoding_length_pair_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
@@ -1142,7 +1170,8 @@ def test_maximum_encoding_length_pair_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
@@ -2401,13 +2430,15 @@ def test_prepare_seq2seq_batch(self):
# Longer text that will definitely require truncation.
src_text = [
" UN Chief Says There Is No Military Solution in Syria",
- " Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.",
+ " Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for"
+ " Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons"
+ " will only worsen the violence and misery for millions of people.",
]
tgt_text = [
"Åeful ONU declarÄ cÄ nu existÄ o soluÅ£ie militarÄ Ć®n Siria",
- "Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei "
- 'pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu '
- "vor face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£ele Åi mizeria pentru milioane de oameni.",
+ "Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al"
+ ' Rusiei pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi'
+ " cÄ noi arme nu vor face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£ele Åi mizeria pentru milioane de oameni.",
]
try:
batch = tokenizer.prepare_seq2seq_batch(
@@ -3319,6 +3350,11 @@ def test_save_pretrained(self):
tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2)
tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)
+ # make sure that all ".json" files are saved in the correct format
+ for file_path in tokenizer_r_files + tokenizer_p_files:
+ if os.path.exists(file_path) and file_path.endswith(".json"):
+ check_json_file_has_correct_format(file_path)
+
# Checks it save with the same files + the tokenizer.json file for the fast one
self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files))
tokenizer_r_files = tuple(f for f in tokenizer_r_files if "tokenizer.json" not in f)
@@ -3658,11 +3694,9 @@ def test_training_new_tokenizer_with_special_tokens_change(self):
break
self.assertTrue(
find,
- (
- f"'{new_special_token_str}' doesn't appear in the list "
- f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
- f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}"
- ),
+ f"'{new_special_token_str}' doesn't appear in the list "
+ f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
+ f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}",
)
elif special_token not in special_tokens_map:
# The special token must appear identically in the list of the new tokenizer.
@@ -3725,7 +3759,8 @@ def test_tokenizer_mismatch_warning(self):
finally:
self.assertTrue(
cm.records[0].message.startswith(
- "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
+ "The tokenizer class you load from this checkpoint is not the same type as the class"
+ " this function is called from."
)
)
@@ -3794,14 +3829,14 @@ def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
- response_mock.headers = []
+ response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
- with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
+ with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check ensures that we did call the fake head request
mock_head.assert_called()
@@ -3813,22 +3848,24 @@ class TokenizerPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls._token = login(username=USER, password=PASS)
+ cls._token = TOKEN
+ set_access_token(TOKEN)
+ HfFolder.save_token(TOKEN)
@classmethod
def tearDownClass(cls):
try:
- delete_repo(token=cls._token, name="test-tokenizer")
+ delete_repo(token=cls._token, repo_id="test-tokenizer")
except HTTPError:
pass
try:
- delete_repo(token=cls._token, name="test-tokenizer-org", organization="valid_org")
+ delete_repo(token=cls._token, repo_id="valid_org/test-tokenizer-org")
except HTTPError:
pass
try:
- delete_repo(token=cls._token, name="test-dynamic-tokenizer")
+ delete_repo(token=cls._token, repo_id="test-dynamic-tokenizer")
except HTTPError:
pass
@@ -3838,12 +3875,20 @@ def test_push_to_hub(self):
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
- tokenizer.save_pretrained(
- os.path.join(tmp_dir, "test-tokenizer"), push_to_hub=True, use_auth_token=self._token
- )
- new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
- self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
+ tokenizer.push_to_hub("test-tokenizer", use_auth_token=self._token)
+ new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
+ self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
+
+ # Reset repo
+ delete_repo(token=self._token, repo_id="test-tokenizer")
+
+ # Push to hub via save_pretrained
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tokenizer.save_pretrained(tmp_dir, repo_id="test-tokenizer", push_to_hub=True, use_auth_token=self._token)
+
+ new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
+ self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
def test_push_to_hub_in_organization(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -3851,15 +3896,22 @@ def test_push_to_hub_in_organization(self):
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
+
+ tokenizer.push_to_hub("valid_org/test-tokenizer-org", use_auth_token=self._token)
+ new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
+ self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
+
+ # Reset repo
+ delete_repo(token=self._token, repo_id="valid_org/test-tokenizer-org")
+
+ # Push to hub via save_pretrained
+ with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(
- os.path.join(tmp_dir, "test-tokenizer-org"),
- push_to_hub=True,
- use_auth_token=self._token,
- organization="valid_org",
+ tmp_dir, repo_id="valid_org/test-tokenizer-org", push_to_hub=True, use_auth_token=self._token
)
- new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
- self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
+ new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
+ self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
@require_tokenizers
def test_push_to_hub_dynamic_tokenizer(self):
@@ -3871,17 +3923,7 @@ def test_push_to_hub_dynamic_tokenizer(self):
tokenizer = CustomTokenizer(vocab_file)
# No fast custom tokenizer
- with tempfile.TemporaryDirectory() as tmp_dir:
- repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
- tokenizer.save_pretrained(tmp_dir)
-
- with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
- tokenizer_config = json.load(f)
- self.assertDictEqual(
- tokenizer_config["auto_map"], {"AutoTokenizer": ["custom_tokenization.CustomTokenizer", None]}
- )
-
- repo.push_to_hub()
+ tokenizer.push_to_hub("test-dynamic-tokenizer", use_auth_token=self._token)
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module
@@ -3898,23 +3940,7 @@ def test_push_to_hub_dynamic_tokenizer(self):
bert_tokenizer.save_pretrained(tmp_dir)
tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
- with tempfile.TemporaryDirectory() as tmp_dir:
- repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
- tokenizer.save_pretrained(tmp_dir)
-
- with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
- tokenizer_config = json.load(f)
- self.assertDictEqual(
- tokenizer_config["auto_map"],
- {
- "AutoTokenizer": [
- "custom_tokenization.CustomTokenizer",
- "custom_tokenization_fast.CustomTokenizerFast",
- ]
- },
- )
-
- repo.push_to_hub()
+ tokenizer.push_to_hub("test-dynamic-tokenizer", use_auth_token=self._token)
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
diff --git a/tests/tokenization/test_tokenization_fast.py b/tests/tokenization/test_tokenization_fast.py
index 9e5ad178e53a..da98d17d7722 100644
--- a/tests/tokenization/test_tokenization_fast.py
+++ b/tests/tokenization/test_tokenization_fast.py
@@ -39,6 +39,7 @@ def setUp(self):
self.test_rust_tokenizer = True
model_paths = ["robot-test/dummy-tokenizer-fast", "robot-test/dummy-tokenizer-wordlevel"]
+ self.bytelevel_bpe_model_name = "SaulLu/dummy-tokenizer-bytelevel-bpe"
# Inclusion of 2 tokenizers to test different types of models (Unigram and WordLevel for the moment)
self.tokenizers_list = [(PreTrainedTokenizerFast, model_path, {}) for model_path in model_paths]
@@ -99,6 +100,15 @@ def test_training_new_tokenizer_with_special_tokens_change(self):
shutil.rmtree(self.tmpdirname)
self.tmpdirname = tmpdirname_orig
+ def test_training_new_tokenizer_with_bytelevel(self):
+ tokenizer = self.rust_tokenizer_class.from_pretrained(self.bytelevel_bpe_model_name)
+
+ toy_text_iterator = ("a" for _ in range(1000))
+ new_tokenizer = tokenizer.train_new_from_iterator(text_iterator=toy_text_iterator, length=1000, vocab_size=50)
+
+ encoding_ids = new_tokenizer.encode("a🤗")
+ self.assertEqual(encoding_ids, [64, 172, 253, 97, 245])
+
@require_tokenizers
class TokenizerVersioningTest(unittest.TestCase):
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index f9df63c15e34..9cdb02468b30 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -21,6 +21,7 @@
import random
import re
import subprocess
+import sys
import tempfile
import time
import unittest
@@ -29,7 +30,7 @@
import numpy as np
-from huggingface_hub import Repository, delete_repo, login
+from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
from parameterized import parameterized
from requests.exceptions import HTTPError
from transformers import (
@@ -42,24 +43,29 @@
)
from transformers.testing_utils import (
ENDPOINT_STAGING,
- PASS,
+ TOKEN,
USER,
CaptureLogger,
TestCasePlus,
get_gpu_count,
get_tests_dir,
is_staging_test,
+ require_intel_extension_for_pytorch,
require_optuna,
require_ray,
require_sentencepiece,
require_sigopt,
require_tokenizers,
require_torch,
- require_torch_bf16,
+ require_torch_bf16_cpu,
+ require_torch_bf16_gpu,
require_torch_gpu,
require_torch_multi_gpu,
+ require_torch_non_multi_gpu,
+ require_torch_tensorrt_fx,
require_torch_tf32,
require_torch_up_to_2_gpus,
+ require_torchdynamo,
require_wandb,
slow,
)
@@ -550,7 +556,7 @@ def test_adafactor_lr_none(self):
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
@require_torch_gpu
- @require_torch_bf16
+ @require_torch_bf16_gpu
def test_mixed_bf16(self):
# very basic test
@@ -637,6 +643,29 @@ def test_number_of_steps_in_training(self):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)
+ @require_torch_bf16_cpu
+ @require_intel_extension_for_pytorch
+ def test_number_of_steps_in_training_with_ipex(self):
+ for mix_bf16 in [True, False]:
+ # Regular training has n_epochs * len(train_dl) steps
+ trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True)
+ train_output = trainer.train()
+ self.assertEqual(train_output.global_step, self.n_epochs * 64 / trainer.args.train_batch_size)
+
+ # Check passing num_train_epochs works (and a float version too):
+ trainer = get_regression_trainer(
+ learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True
+ )
+ train_output = trainer.train()
+ self.assertEqual(train_output.global_step, int(1.5 * 64 / trainer.args.train_batch_size))
+
+ # If we pass a max_steps, num_train_epochs is ignored
+ trainer = get_regression_trainer(
+ learning_rate=0.1, max_steps=10, use_ipex=True, bf16=mix_bf16, no_cuda=True
+ )
+ train_output = trainer.train()
+ self.assertEqual(train_output.global_step, 10)
+
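Outside of the regression-test helper, the IPEX path under test is enabled through `TrainingArguments`; a sketch with a placeholder output directory (requires intel_extension_for_pytorch to be installed):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="ipex-run", use_ipex=True, no_cuda=True)  # add bf16=True for the mixed-bf16 variants tested above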
def test_logging_inf_nan_filter(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
tiny_gpt2 = GPT2LMHeadModel(config)
@@ -817,6 +846,101 @@ def test_evaluate(self):
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+ def test_evaluate_with_jit(self):
+ trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), jit_mode_eval=True)
+ results = trainer.evaluate()
+
+ x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+ pred = 1.5 * x + 2.5
+ expected_loss = ((pred - y) ** 2).mean()
+ self.assertAlmostEqual(results["eval_loss"], expected_loss)
+ expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
+ self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
+ # With a number of elements not a round multiple of the batch size
+ trainer = get_regression_trainer(
+ a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracy(), jit_mode_eval=True
+ )
+ results = trainer.evaluate()
+
+ x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+ pred = 1.5 * x + 2.5
+ expected_loss = ((pred - y) ** 2).mean()
+ self.assertAlmostEqual(results["eval_loss"], expected_loss)
+ expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
+ self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
+ # With logits preprocess
+ trainer = get_regression_trainer(
+ a=1.5,
+ b=2.5,
+ compute_metrics=AlmostAccuracy(),
+ preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
+ jit_mode_eval=True,
+ )
+ results = trainer.evaluate()
+
+ x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+ pred = 1.5 * x + 2.5
+ expected_loss = ((pred - y) ** 2).mean()
+ self.assertAlmostEqual(results["eval_loss"], expected_loss)
+ expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
+ self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
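The JIT evaluation path exercised here corresponds to a single flag on `TrainingArguments`; a sketch with a placeholder output directory:

from transformers import TrainingArguments

args = TrainingArguments(output_dir="jit-eval-run", jit_mode_eval=True)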
+ @require_torch_bf16_cpu
+ @require_intel_extension_for_pytorch
+ def test_evaluate_with_ipex(self):
+ for mix_bf16 in [True, False]:
+ trainer = get_regression_trainer(
+ a=1.5, b=2.5, use_ipex=True, compute_metrics=AlmostAccuracy(), bf16=mix_bf16, no_cuda=True
+ )
+ results = trainer.evaluate()
+
+ x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+ pred = 1.5 * x + 2.5
+ expected_loss = ((pred - y) ** 2).mean()
+ self.assertAlmostEqual(results["eval_loss"], expected_loss)
+ expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
+ self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
+ # With a number of elements not a round multiple of the batch size
+ trainer = get_regression_trainer(
+ a=1.5,
+ b=2.5,
+ use_ipex=True,
+ eval_len=66,
+ compute_metrics=AlmostAccuracy(),
+ bf16=mix_bf16,
+ no_cuda=True,
+ )
+ results = trainer.evaluate()
+
+ x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+ pred = 1.5 * x + 2.5
+ expected_loss = ((pred - y) ** 2).mean()
+ self.assertAlmostEqual(results["eval_loss"], expected_loss)
+ expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
+ self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
+ # With logits preprocess
+ trainer = get_regression_trainer(
+ a=1.5,
+ b=2.5,
+ use_ipex=True,
+ compute_metrics=AlmostAccuracy(),
+ preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
+ bf16=mix_bf16,
+ no_cuda=True,
+ )
+ results = trainer.evaluate()
+
+ x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+ pred = 1.5 * x + 2.5
+ expected_loss = ((pred - y) ** 2).mean()
+ self.assertAlmostEqual(results["eval_loss"], expected_loss)
+ expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
+ self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
def test_predict(self):
trainer = get_regression_trainer(a=1.5, b=2.5)
preds = trainer.predict(trainer.eval_dataset).predictions
@@ -849,6 +973,85 @@ def test_predict(self):
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
+ def test_predict_with_jit(self):
+ trainer = get_regression_trainer(a=1.5, b=2.5, jit_mode_eval=True)
+ preds = trainer.predict(trainer.eval_dataset).predictions
+ x = trainer.eval_dataset.x
+ self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
+
+ # With a number of elements not a round multiple of the batch size
+ trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, jit_mode_eval=True)
+ preds = trainer.predict(trainer.eval_dataset).predictions
+ x = trainer.eval_dataset.x
+ self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
+
+ # With more than one output of the model
+ trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, jit_mode_eval=True)
+ preds = trainer.predict(trainer.eval_dataset).predictions
+ x = trainer.eval_dataset.x
+ self.assertEqual(len(preds), 2)
+ self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
+ self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
+
+ # With more than one output/label of the model
+ trainer = get_regression_trainer(
+ a=1.5, b=2.5, double_output=True, label_names=["labels", "labels_2"], jit_mode_eval=True
+ )
+ outputs = trainer.predict(trainer.eval_dataset)
+ preds = outputs.predictions
+ labels = outputs.label_ids
+ x = trainer.eval_dataset.x
+ self.assertEqual(len(preds), 2)
+ self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
+ self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
+ self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
+ self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
+
+ @require_torch_bf16_cpu
+ @require_intel_extension_for_pytorch
+ def test_predict_with_ipex(self):
+ for mix_bf16 in [True, False]:
+ trainer = get_regression_trainer(a=1.5, b=2.5, use_ipex=True, bf16=mix_bf16, no_cuda=True)
+ preds = trainer.predict(trainer.eval_dataset).predictions
+ x = trainer.eval_dataset.x
+ self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
+
+ # With a number of elements not a round multiple of the batch size
+ trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, use_ipex=True, bf16=mix_bf16, no_cuda=True)
+ preds = trainer.predict(trainer.eval_dataset).predictions
+ x = trainer.eval_dataset.x
+ self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
+
+ # With more than one output of the model
+ trainer = get_regression_trainer(
+ a=1.5, b=2.5, double_output=True, use_ipex=True, bf16=mix_bf16, no_cuda=True
+ )
+ preds = trainer.predict(trainer.eval_dataset).predictions
+ x = trainer.eval_dataset.x
+ self.assertEqual(len(preds), 2)
+ self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
+ self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
+
+ # With more than one output/label of the model
+ trainer = get_regression_trainer(
+ a=1.5,
+ b=2.5,
+ double_output=True,
+ label_names=["labels", "labels_2"],
+ use_ipex=True,
+ bf16=mix_bf16,
+ no_cuda=True,
+ )
+ outputs = trainer.predict(trainer.eval_dataset)
+ preds = outputs.predictions
+ labels = outputs.label_ids
+ x = trainer.eval_dataset.x
+ self.assertEqual(len(preds), 2)
+ self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
+ self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
+ self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
+ self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
+
def test_dynamic_shapes(self):
eval_dataset = DynamicShapesDataset(batch_size=self.batch_size)
model = RegressionModel(a=2, b=1)
@@ -1047,8 +1250,8 @@ def test_resume_training_with_randomness(self):
trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, "checkpoint-15"))
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
- self.assertAlmostEqual(a, a1, delta=1e-8)
- self.assertAlmostEqual(b, b1, delta=1e-8)
+ self.assertAlmostEqual(a, a1, delta=1e-5)
+ self.assertAlmostEqual(b, b1, delta=1e-5)
with self.subTest("Test every epoch"):
config = RegressionModelConfig(a=0, b=2, random_torch=random_torch)
@@ -1072,8 +1275,43 @@ def test_resume_training_with_randomness(self):
trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, checkpoint_dir))
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
- self.assertAlmostEqual(a, a1, delta=1e-8)
- self.assertAlmostEqual(b, b1, delta=1e-8)
+ self.assertAlmostEqual(a, a1, delta=1e-5)
+ self.assertAlmostEqual(b, b1, delta=1e-5)
+
+ @slow
+ @require_torch_non_multi_gpu
+ def test_auto_batch_size_finder(self):
+
+ if torch.cuda.is_available():
+ torch.backends.cudnn.deterministic = True
+
+ SRC_DIR = os.path.abspath(
+ os.path.join(os.path.dirname(__file__), "..", "..", "examples", "pytorch", "text-classification")
+ )
+ sys.path.append(SRC_DIR)
+ import run_glue
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ testargs = f"""
+ run_glue.py
+ --model_name_or_path distilbert-base-uncased
+ --task_name mrpc
+ --do_train
+ --do_eval
+ --max_seq_len 128
+ --per_device_train_batch_size 4096
+ --learning_rate 2e-5
+ --num_train_epochs 1
+ --output_dir {tmpdir}
+ --auto_find_batch_size 0
+ """.split()
+ with self.assertRaises(RuntimeError):
+ with patch.object(sys, "argv", testargs):
+ run_glue.main()
+
+ testargs[-1] = "1"
+ with patch.object(sys, "argv", testargs):
+ run_glue.main()
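The `--auto_find_batch_size` flag used in the example script maps directly onto `TrainingArguments`; a sketch with a placeholder output directory (the automatic halving-and-retrying requires `accelerate`):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="auto-bs-run", per_device_train_batch_size=4096, auto_find_batch_size=True)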
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
def test_training_with_resume_from_checkpoint_false(self):
@@ -1292,7 +1530,8 @@ def test_trainer_eval_lm(self):
def test_training_iterable_dataset(self):
config = RegressionModelConfig()
model = RegressionPreTrainedModel(config)
- train_dataset = SampleIterableDataset()
+ # Adding one column not used by the model should have no impact
+ train_dataset = SampleIterableDataset(label_names=["labels", "extra"])
args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
@@ -1326,7 +1565,8 @@ def test_training_finite_iterable_dataset(self):
def test_evaluation_iterable_dataset(self):
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
- eval_dataset = SampleIterableDataset()
+ # Adding one column not used by the model should have no impact
+ eval_dataset = SampleIterableDataset(label_names=["labels", "extra"])
args = RegressionTrainingArguments(output_dir="./examples")
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
@@ -1363,7 +1603,8 @@ def test_predict_iterable_dataset(self):
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
# With a number of elements not a round multiple of the batch size
- test_dataset = SampleIterableDataset(length=66)
+ # Adding one column not used by the model should have no impact
+ test_dataset = SampleIterableDataset(length=66, label_names=["labels", "extra"])
preds = trainer.predict(test_dataset).predictions
x = test_dataset.dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
@@ -1511,7 +1752,7 @@ def test_fp16_full_eval(self):
a = torch.ones(1000, bs) + 0.001
b = torch.ones(1000, bs) - 0.001
- # 1. with mem metrics enabled
+ # 1. with fp16_full_eval disabled
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False)
metrics = trainer.evaluate()
del trainer
@@ -1532,7 +1773,7 @@ def test_fp16_full_eval(self):
# perfect world: fp32_eval == close to zero
self.assertLess(fp32_eval, 5_000)
- # 2. with mem metrics disabled
+ # 2. with fp16_full_eval enabled
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, fp16_full_eval=True, skip_memory_metrics=False)
metrics = trainer.evaluate()
fp16_init = metrics["init_mem_gpu_alloc_delta"]
@@ -1554,8 +1795,121 @@ def test_fp16_full_eval(self):
# perfect world: fp32_init/2 == fp16_eval
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
+ @require_torch_non_multi_gpu
+ @require_torchdynamo
+ @require_torch_tensorrt_fx
+ def test_torchdynamo_full_eval(self):
+ # torchdynamo at the moment doesn't support DP/DDP, therefore this test requires a single GPU
+ n_gpus = get_gpu_count()
+
+ bs = 8
+ eval_len = 16 * n_gpus
+ # make the params somewhat big so that there will be enough RAM consumed to be able to
+ # measure things. We should get about 64KB for a+b in fp32
+ a = torch.ones(1000, bs) + 0.001
+ b = torch.ones(1000, bs) - 0.001
+
+ # 1. Default - without TorchDynamo
+ trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len)
+ metrics = trainer.evaluate()
+ original_eval_loss = metrics["eval_loss"]
+ del trainer
+
+ # 2. TorchDynamo eager
+ trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="eager")
+ metrics = trainer.evaluate()
+ self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
+ del trainer
+
+ # 3. TorchDynamo nvfuser
+ trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="nvfuser")
+ metrics = trainer.evaluate()
+ self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
+
+ # 4. TorchDynamo fx2trt
+ trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
+ metrics = trainer.evaluate()
+ t1 = metrics["eval_loss"]
+ t2 = original_eval_loss
+ self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
+
+ # 5. TorchDynamo fx2trt-fp16
+ trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
+ metrics = trainer.evaluate()
+ t1 = metrics["eval_loss"]
+ t2 = original_eval_loss
+ # fp16 has some accuracy degradation
+ self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)
+
+ @require_torch_non_multi_gpu
+ @require_torchdynamo
+ def test_torchdynamo_memory(self):
+ # torchdynamo at the moment doesn't support DP/DDP, therefore this test requires a single GPU
+ class CustomTrainer(Trainer):
+ def compute_loss(self, model, inputs, return_outputs=False):
+ x = inputs["x"]
+ output = model(x)
+ if self.args.n_gpu == 1:
+ return output.mean()
+ return output
+
+ class MyModule(torch.nn.Module):
+ """Simple module that does aggressive fusion"""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ for _ in range(20):
+ x = torch.nn.functional.relu(x)
+ return x
+
+ mod = MyModule()
+
+ # 1. without TorchDynamo (eager baseline)
+ a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
+ a.grad = None
+ trainer = CustomTrainer(model=mod)
+ # warmup
+ for _ in range(10):
+ orig_loss = trainer.training_step(mod, {"x": a})
+
+ # resets
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ orig_loss = trainer.training_step(mod, {"x": a})
+ orig_peak_mem = torch.cuda.max_memory_allocated()
+ del trainer
+
+ # 2. TorchDynamo nvfuser
+ a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
+ a.grad = None
+ args = TrainingArguments(output_dir="None", torchdynamo="nvfuser")
+ trainer = CustomTrainer(model=mod, args=args)
+ # warmup
+ for _ in range(10):
+ loss = trainer.training_step(mod, {"x": a})
+
+ # resets
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ loss = trainer.training_step(mod, {"x": a})
+ peak_mem = torch.cuda.max_memory_allocated()
+ del trainer
+
+ # Functional check
+ self.assertAlmostEqual(loss, orig_loss)
+
+ # AOT Autograd recomputation and the nvfuser recomputation optimization
+ # aggressively fuse the operations and reduce the memory footprint.
+ self.assertGreater(orig_peak_mem, peak_mem * 2)
+
@require_torch_gpu
- @require_torch_bf16
+ @require_torch_bf16_gpu
def test_bf16_full_eval(self):
# note: most of the logic is the same as test_fp16_full_eval
@@ -1571,7 +1925,7 @@ def test_bf16_full_eval(self):
a = torch.ones(1000, bs) + 0.001
b = torch.ones(1000, bs) - 0.001
- # 1. with mem metrics enabled
+ # 1. with bf16_full_eval disabled
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False)
metrics = trainer.evaluate()
del trainer
@@ -1592,7 +1946,7 @@ def test_bf16_full_eval(self):
# perfect world: fp32_eval == close to zero
self.assertLess(fp32_eval, 5_000)
- # 2. with mem metrics disabled
+ # 2. with bf16_full_eval enabled
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, bf16_full_eval=True, skip_memory_metrics=False)
metrics = trainer.evaluate()
bf16_init = metrics["init_mem_gpu_alloc_delta"]
@@ -1632,18 +1986,20 @@ def test_no_wd_param_group(self):
class TrainerIntegrationWithHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls._token = login(username=USER, password=PASS)
+ cls._token = TOKEN
+ set_access_token(TOKEN)
+ HfFolder.save_token(TOKEN)
@classmethod
def tearDownClass(cls):
for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step"]:
try:
- delete_repo(token=cls._token, name=model)
+ delete_repo(token=cls._token, repo_id=model)
except HTTPError:
pass
try:
- delete_repo(token=cls._token, name="test-trainer-org", organization="valid_org")
+ delete_repo(token=cls._token, repo_id="valid_org/test-trainer-org")
except HTTPError:
pass
@@ -1861,6 +2217,7 @@ def test_hyperparameter_search_ray_client(self):
self.ray_hyperparameter_search()
+@slow
@require_torch
@require_sigopt
class TrainerHyperParameterSigOptIntegrationTest(unittest.TestCase):
diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py
index a7daee4fd08d..a88ca1cb0d49 100644
--- a/tests/trainer/test_trainer_callback.py
+++ b/tests/trainer/test_trainer_callback.py
@@ -66,6 +66,9 @@ def on_step_end(self, args, state, control, **kwargs):
def on_evaluate(self, args, state, control, **kwargs):
self.events.append("on_evaluate")
+ def on_predict(self, args, state, control, **kwargs):
+ self.events.append("on_predict")
+
def on_save(self, args, state, control, **kwargs):
self.events.append("on_save")
diff --git a/tests/trainer/test_trainer_utils.py b/tests/trainer/test_trainer_utils.py
index 7710892d8d79..869d19b0a1e6 100644
--- a/tests/trainer/test_trainer_utils.py
+++ b/tests/trainer/test_trainer_utils.py
@@ -18,7 +18,9 @@
import numpy as np
-from transformers.testing_utils import require_torch
+from transformers.data.data_collator import default_data_collator
+from transformers.testing_utils import require_accelerate, require_torch
+from transformers.trainer_utils import RemoveColumnsCollator, find_executable_batch_size
from transformers.utils import is_torch_available
@@ -39,6 +41,8 @@
SequentialDistributedSampler,
ShardSampler,
get_parameter_names,
+ numpy_pad_and_concatenate,
+ torch_pad_and_concatenate,
)
class TstLayer(nn.Module):
@@ -420,3 +424,76 @@ def test_shard_sampler(self):
self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)
+
+ @require_accelerate
+ def test_executable_batch_size(self):
+ batch_sizes = []
+
+ @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=True)
+ def mock_training_loop_function(batch_size):
+ nonlocal batch_sizes
+ batch_sizes.append(batch_size)
+ if batch_size > 16:
+ raise RuntimeError("CUDA out of memory.")
+
+ mock_training_loop_function()
+ self.assertEqual(batch_sizes, [64, 32, 16])
+
+ @require_accelerate
+ def test_executable_batch_size_no_search(self):
+ batch_sizes = []
+
+ @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
+ def mock_training_loop_function(batch_size):
+ nonlocal batch_sizes
+ batch_sizes.append(batch_size)
+
+ mock_training_loop_function()
+ self.assertEqual(batch_sizes, [64])
+
+ @require_accelerate
+ def test_executable_batch_size_with_error(self):
+ @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
+ def mock_training_loop_function(batch_size):
+ raise RuntimeError("CUDA out of memory.")
+
+ with self.assertRaises(RuntimeError) as cm:
+ mock_training_loop_function()
+ self.assertEqual("CUDA out of memory", cm.args[0])
+
+ def test_pad_and_concatenate_with_1d(self):
+ """Tests whether pad_and_concatenate works with scalars."""
+ array1 = 1.0
+ array2 = 2.0
+ result = numpy_pad_and_concatenate(array1, array2)
+ self.assertTrue(np.array_equal(np.array([1.0, 2.0]), result))
+
+ tensor1 = torch.tensor(1.0)
+ tensor2 = torch.tensor(2.0)
+ result = torch_pad_and_concatenate(tensor1, tensor2)
+ self.assertTrue(torch.equal(result, torch.Tensor([1.0, 2.0])))
+
+ def test_remove_columns_collator(self):
+ class MockLogger:
+ def __init__(self) -> None:
+ self.called = 0
+
+ def info(self, msg):
+ self.called += 1
+ self.last_msg = msg
+
+ data_batch = [
+ {"col1": 1, "col2": 2, "col3": 3},
+ {"col1": 1, "col2": 2, "col3": 3},
+ ]
+ logger = MockLogger()
+ remove_columns_collator = RemoveColumnsCollator(
+ default_data_collator, ["col1", "col2"], logger, "model", "training"
+ )
+
+ self.assertNotIn("col3", remove_columns_collator(data_batch))
+ # check that the logging message is printed out only once
+ remove_columns_collator(data_batch)
+ remove_columns_collator(data_batch)
+ self.assertEqual(logger.called, 1)
+ self.assertIn("col3", logger.last_msg)
diff --git a/tests/utils/test_cli.py b/tests/utils/test_cli.py
index 1e5ba4fa27c9..f39aa600679a 100644
--- a/tests/utils/test_cli.py
+++ b/tests/utils/test_cli.py
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
+import shutil
import unittest
from unittest.mock import patch
-from transformers.testing_utils import CaptureStd
+from transformers.testing_utils import CaptureStd, is_pt_tf_cross_test
class CLITest(unittest.TestCase):
@@ -30,3 +32,16 @@ def test_cli_env(self):
self.assertIn("Python version", cs.out)
self.assertIn("Platform", cs.out)
self.assertIn("Using distributed or parallel set-up in script?", cs.out)
+
+ @is_pt_tf_cross_test
+ @patch(
+ "sys.argv", ["fakeprogrampath", "pt-to-tf", "--model-name", "hf-internal-testing/tiny-random-gptj", "--no-pr"]
+ )
+ def test_cli_pt_to_tf(self):
+ import transformers.commands.transformers_cli
+
+ shutil.rmtree("/tmp/hf-internal-testing/tiny-random-gptj", ignore_errors=True) # cleans potential past runs
+ transformers.commands.transformers_cli.main()
+
+ # The original repo has no TF weights -- if they exist, they were created by the CLI
+ self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5"))
diff --git a/tests/utils/test_convert_slow_tokenizer.py b/tests/utils/test_convert_slow_tokenizer.py
index f7bb60acfdb0..8655ea4602e7 100644
--- a/tests/utils/test_convert_slow_tokenizer.py
+++ b/tests/utils/test_convert_slow_tokenizer.py
@@ -28,9 +28,7 @@ def test_spm_converter_bytefallback_warning(self):
_ = SpmConverter(original_tokenizer_with_bytefallback)
self.assertEqual(len(w), 1)
self.assertIn(
- (
- "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
- " which is not implemented in the fast tokenizers."
- ),
+ "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
+ " which is not implemented in the fast tokenizers.",
str(w[0].message),
)
diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py
index 75c4f19caa1d..60676e9f7d9d 100644
--- a/tests/utils/test_file_utils.py
+++ b/tests/utils/test_file_utils.py
@@ -26,20 +26,13 @@
from transformers import * # noqa F406
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
from transformers.utils import (
- CONFIG_NAME,
FLAX_WEIGHTS_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
ContextManagers,
- EntryNotFoundError,
- RepositoryNotFoundError,
- RevisionNotFoundError,
- filename_to_url,
find_labels,
get_file_from_repo,
- get_from_cache,
has_file,
- hf_bucket_url,
is_flax_available,
is_tf_available,
is_torch_available,
@@ -85,52 +78,6 @@ def test_module_spec_available(self):
class GetFromCacheTests(unittest.TestCase):
- def test_bogus_url(self):
- # This lets us simulate no connection
- # as the error raised is the same
- # `ConnectionError`
- url = "https://bogus"
- with self.assertRaisesRegex(ValueError, "Connection error"):
- _ = get_from_cache(url)
-
- def test_file_not_found(self):
- # Valid revision (None) but missing file.
- url = hf_bucket_url(MODEL_ID, filename="missing.bin")
- with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"):
- _ = get_from_cache(url)
-
- def test_model_not_found(self):
- # Invalid model file.
- url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
- with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
- _ = get_from_cache(url)
-
- def test_revision_not_found(self):
- # Valid file but missing revision
- url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
- with self.assertRaisesRegex(RevisionNotFoundError, "404 Client Error"):
- _ = get_from_cache(url)
-
- def test_standard_object(self):
- url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
- filepath = get_from_cache(url, force_download=True)
- metadata = filename_to_url(filepath)
- self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"'))
-
- def test_standard_object_rev(self):
- # Same object, but different revision
- url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT)
- filepath = get_from_cache(url, force_download=True)
- metadata = filename_to_url(filepath)
- self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
- # Caution: check that the etag is *not* equal to the one from `test_standard_object`
-
- def test_lfs_object(self):
- url = hf_bucket_url(MODEL_ID, filename=WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
- filepath = get_from_cache(url, force_download=True)
- metadata = filename_to_url(filepath)
- self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
-
def test_has_file(self):
self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME))
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))
diff --git a/tests/utils/test_generic.py b/tests/utils/test_generic.py
new file mode 100644
index 000000000000..6fbdbee40360
--- /dev/null
+++ b/tests/utils/test_generic.py
@@ -0,0 +1,45 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers.utils import flatten_dict
+
+
+class GenericTester(unittest.TestCase):
+ def test_flatten_dict(self):
+ input_dict = {
+ "task_specific_params": {
+ "summarization": {"length_penalty": 1.0, "max_length": 128, "min_length": 12, "num_beams": 4},
+ "summarization_cnn": {"length_penalty": 2.0, "max_length": 142, "min_length": 56, "num_beams": 4},
+ "summarization_xsum": {"length_penalty": 1.0, "max_length": 62, "min_length": 11, "num_beams": 6},
+ }
+ }
+ expected_dict = {
+ "task_specific_params.summarization.length_penalty": 1.0,
+ "task_specific_params.summarization.max_length": 128,
+ "task_specific_params.summarization.min_length": 12,
+ "task_specific_params.summarization.num_beams": 4,
+ "task_specific_params.summarization_cnn.length_penalty": 2.0,
+ "task_specific_params.summarization_cnn.max_length": 142,
+ "task_specific_params.summarization_cnn.min_length": 56,
+ "task_specific_params.summarization_cnn.num_beams": 4,
+ "task_specific_params.summarization_xsum.length_penalty": 1.0,
+ "task_specific_params.summarization_xsum.max_length": 62,
+ "task_specific_params.summarization_xsum.min_length": 11,
+ "task_specific_params.summarization_xsum.num_beams": 6,
+ }
+
+ self.assertEqual(flatten_dict(input_dict), expected_dict)
diff --git a/tests/utils/test_model_card.py b/tests/utils/test_model_card.py
index 1004642a92a2..7d0e8795e0aa 100644
--- a/tests/utils/test_model_card.py
+++ b/tests/utils/test_model_card.py
@@ -38,7 +38,10 @@ def setUp(self):
},
"training_data": {
"Dataset": "English Wikipedia dump dated 2018-12-01",
- "Preprocessing": "Using SentencePiece vocabulary of size 52k tokens. See details on https://arxiv.org/pdf/1810.03993.pdf",
+ "Preprocessing": (
+ "Using SentencePiece vocabulary of size 52k tokens. See details on"
+ " https://arxiv.org/pdf/1810.03993.pdf"
+ ),
},
"quantitative_analyses": {"BLEU": 55.1, "ROUGE-1": 76},
}
diff --git a/tests/utils/test_modeling_tf_core.py b/tests/utils/test_modeling_tf_core.py
index 8edfc8eab02d..0863528708e3 100644
--- a/tests/utils/test_modeling_tf_core.py
+++ b/tests/utils/test_modeling_tf_core.py
@@ -18,6 +18,7 @@
import os
import tempfile
from importlib import import_module
+from math import isnan
from transformers import is_tf_available
from transformers.models.auto import get_values
@@ -135,25 +136,70 @@ def run_in_graph_mode():
self.assertIsNotNone(outputs)
@slow
- def test_saved_model_creation(self):
+ def test_xla_fit(self):
+ # This is a copy of the test_keras_fit method, but we use XLA compilation instead of eager
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.output_hidden_states = False
- config.output_attentions = False
-
- if hasattr(config, "use_cache"):
- config.use_cache = False
-
- model_class = self.all_model_classes[0]
-
- class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
- model = model_class(config)
-
- model(class_inputs_dict)
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ if getattr(model, "hf_compute_loss", None):
+ # Test that the model correctly computes the loss with kwargs
+ prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
+ # Is there a better way to remove these decoder inputs?
+ prepared_for_class = {
+ key: val
+ for key, val in prepared_for_class.items()
+ if key not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids")
+ }
+
+ possible_label_cols = {
+ "labels",
+ "label",
+ "label_ids",
+ "start_positions",
+ "start_position",
+ "end_positions",
+ "end_position",
+ "next_sentence_label",
+ }
+ label_names = possible_label_cols.intersection(set(prepared_for_class))
+ self.assertGreater(len(label_names), 0, msg="No matching label names found!")
+ labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
+ inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
+ self.assertGreater(len(inputs_minus_labels), 0)
+
+ # Make sure it works with XLA!
+ model.compile(optimizer=tf.keras.optimizers.SGD(0.0), jit_compile=True)
+ # Make sure the model fits without crashing regardless of where we pass the labels
+ history = model.fit(
+ prepared_for_class,
+ validation_data=prepared_for_class,
+ steps_per_epoch=1,
+ validation_steps=1,
+ shuffle=False,
+ verbose=0,
+ )
+ loss = history.history["loss"][0]
+ self.assertTrue(not isnan(loss))
+ val_loss = history.history["val_loss"][0]
+ self.assertTrue(not isnan(val_loss))
+
+ # Now test it with separate labels, to make sure that path works in XLA too.
+ model = model_class(config)
+ model.compile(optimizer=tf.keras.optimizers.SGD(0.0), jit_compile=True)
+ history = model.fit(
+ inputs_minus_labels,
+ labels,
+ validation_data=(inputs_minus_labels, labels),
+ steps_per_epoch=1,
+ validation_steps=1,
+ shuffle=False,
+ verbose=0,
+ )
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname, saved_model=True)
- saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
- self.assertTrue(os.path.exists(saved_model_dir))
+ loss = history.history["loss"][0]
+ self.assertTrue(not isnan(loss))
+ val_loss = history.history["val_loss"][0]
+ self.assertTrue(not isnan(val_loss))
@slow
def test_saved_model_creation_extended(self):
@@ -205,18 +251,19 @@ def test_saved_model_creation_extended(self):
@slow
def test_mixed_precision(self):
- tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
- model = model_class(config)
- outputs = model(class_inputs_dict)
-
- self.assertIsNotNone(outputs)
+ tf.keras.mixed_precision.set_global_policy("mixed_float16")
+
+ # try/finally block to ensure subsequent tests run in float32
+ try:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ model = model_class(config)
+ outputs = model(class_inputs_dict)
- tf.keras.mixed_precision.experimental.set_policy("float32")
+ self.assertIsNotNone(outputs)
+ finally:
+ tf.keras.mixed_precision.set_global_policy("float32")
@slow
def test_train_pipeline_custom_model(self):
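# A self-contained sketch of the XLA path the new `test_xla_fit` above exercises:
# `jit_compile=True` in `Model.compile` makes Keras compile the train/test steps with XLA.
# The toy model and random data below are illustrative only.
import numpy as np
import tensorflow as tf

toy_model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
# Zero learning rate, as in the test, so only the compiled forward/backward pass is checked.
toy_model.compile(optimizer=tf.keras.optimizers.SGD(0.0), loss="mse", jit_compile=True)
x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 2).astype("float32")
history = toy_model.fit(x, y, epochs=1, verbose=0)
assert not np.isnan(history.history["loss"][0])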
diff --git a/tests/utils/test_offline.py b/tests/utils/test_offline.py
index 33f5d4bd0a8d..0636a4399e89 100644
--- a/tests/utils/test_offline.py
+++ b/tests/utils/test_offline.py
@@ -34,7 +34,7 @@ def test_offline_mode(self):
"""
run = """
-mname = "lysandre/tiny-bert-random"
+mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
@@ -69,3 +69,53 @@ def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 0, result.stderr)
self.assertIn("success", result.stdout.decode())
+
+ @require_torch
+ def test_offline_mode_sharded_checkpoint(self):
+
+ # this test is a bit tricky since TRANSFORMERS_OFFLINE can only be changed before
+ # `transformers` is loaded, and it's too late to change it from inside pytest - so we
+ # change it while running an external program instead
+
+ # python one-liner segments
+
+ # this must be loaded before socket.socket is monkey-patched
+ load = """
+from transformers import BertConfig, BertModel, BertTokenizer
+ """
+
+ run = """
+mname = "hf-internal-testing/tiny-random-bert-sharded"
+BertConfig.from_pretrained(mname)
+BertModel.from_pretrained(mname)
+print("success")
+ """
+
+ mock = """
+import socket
+def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
+socket.socket = offline_socket
+ """
+
+ # baseline - just load from_pretrained with normal network
+ cmd = [sys.executable, "-c", "\n".join([load, run])]
+
+ # should succeed
+ env = self.get_env()
+ result = subprocess.run(cmd, env=env, check=False, capture_output=True)
+ self.assertEqual(result.returncode, 0, result.stderr)
+ self.assertIn("success", result.stdout.decode())
+
+ # next emulate no network
+ cmd = [sys.executable, "-c", "\n".join([load, mock, run])]
+
+ # This no longer fails since other tests leave the model in the cache, so the check below is commented out.
+ # env["TRANSFORMERS_OFFLINE"] = "0"
+ # result = subprocess.run(cmd, env=env, check=False, capture_output=True)
+ # self.assertEqual(result.returncode, 1, result.stderr)
+
+ # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
+ env["TRANSFORMERS_OFFLINE"] = "1"
+ result = subprocess.run(cmd, env=env, check=False, capture_output=True)
+ self.assertEqual(result.returncode, 0, result.stderr)
+ self.assertIn("success", result.stdout.decode())
diff --git a/tests/utils/test_utils_check_copies.py b/tests/utils/test_utils_check_copies.py
index 7c81df714cb9..57cecf6653ff 100644
--- a/tests/utils/test_utils_check_copies.py
+++ b/tests/utils/test_utils_check_copies.py
@@ -125,9 +125,48 @@ def test_is_copy_consistent(self):
def test_convert_to_localized_md(self):
localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]
- md_list = "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.\n1. **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German version of DistilBERT.\n1. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning."
- localized_md_list = "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (ę„čŖ Google Research and the Toyota Technological Institute at Chicago) 伓é论ę [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n"
- converted_md_list_sample = "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (ę„čŖ Google Research and the Toyota Technological Institute at Chicago) 伓é论ę [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n1. **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (ę„čŖ HuggingFace) 伓é论ę [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) ē± Victor Sanh, Lysandre Debut and Thomas Wolf ååøć The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German version of DistilBERT.\n1. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (ę„čŖ Google Research/Stanford University) 伓é论ę [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) ē± Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning ååøć\n"
+ md_list = (
+ "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the"
+ " Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for"
+ " Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong"
+ " Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.\n1."
+ " **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (from HuggingFace),"
+ " released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and"
+ " lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same"
+ " method has been applied to compress GPT2 into"
+ " [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into"
+ " [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation),"
+ " Multilingual BERT into"
+ " [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German"
+ " version of DistilBERT.\n1. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)**"
+ " (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders"
+ " as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang"
+ " Luong, Quoc V. Le, Christopher D. Manning."
+ )
+ localized_md_list = (
+ "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (ę„čŖ Google Research and the"
+ " Toyota Technological Institute at Chicago) 伓é论ę [ALBERT: A Lite BERT for Self-supervised Learning of"
+ " Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian"
+ " Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n"
+ )
+ converted_md_list_sample = (
+ "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (ę„čŖ Google Research and the"
+ " Toyota Technological Institute at Chicago) 伓é论ę [ALBERT: A Lite BERT for Self-supervised Learning of"
+ " Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian"
+ " Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n1."
+ " **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (ę„čŖ HuggingFace) 伓é论ę"
+ " [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and"
+ " lighter](https://arxiv.org/abs/1910.01108) ē± Victor Sanh, Lysandre Debut and Thomas Wolf ååøć The same"
+ " method has been applied to compress GPT2 into"
+ " [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into"
+ " [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation),"
+ " Multilingual BERT into"
+ " [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German"
+ " version of DistilBERT.\n1. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (ę„čŖ"
+ " Google Research/Stanford University) 伓é论ę [ELECTRA: Pre-training text encoders as discriminators rather"
+ " than generators](https://arxiv.org/abs/2003.10555) ē± Kevin Clark, Minh-Thang Luong, Quoc V. Le,"
+ " Christopher D. Manning ååøć\n"
+ )
num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
md_list, localized_md_list, localized_readme["format_model_list"]
@@ -143,9 +182,24 @@ def test_convert_to_localized_md(self):
# Check whether the number of models is equal to README.md after conversion.
self.assertTrue(num_models_equal)
- link_changed_md_list = "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut."
- link_unchanged_md_list = "1. **[ALBERT](https://huggingface.co/transformers/main/model_doc/albert.html)** (ę„čŖ Google Research and the Toyota Technological Institute at Chicago) 伓é论ę [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n"
- converted_md_list_sample = "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (ę„čŖ Google Research and the Toyota Technological Institute at Chicago) 伓é论ę [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n"
+ link_changed_md_list = (
+ "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the"
+ " Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for"
+ " Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong"
+ " Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut."
+ )
+ link_unchanged_md_list = (
+ "1. **[ALBERT](https://huggingface.co/transformers/main/model_doc/albert.html)** (ę„čŖ Google Research and"
+ " the Toyota Technological Institute at Chicago) 伓é论ę [ALBERT: A Lite BERT for Self-supervised Learning of"
+ " Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian"
+ " Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n"
+ )
+ converted_md_list_sample = (
+ "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (ę„čŖ Google Research and the"
+ " Toyota Technological Institute at Chicago) 伓é论ę [ALBERT: A Lite BERT for Self-supervised Learning of"
+ " Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian"
+ " Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n"
+ )
num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
link_changed_md_list, link_unchanged_md_list, localized_readme["format_model_list"]
diff --git a/utils/check_config_docstrings.py b/utils/check_config_docstrings.py
index 382f42bfe159..bcbbace39e0e 100644
--- a/utils/check_config_docstrings.py
+++ b/utils/check_config_docstrings.py
@@ -41,6 +41,8 @@
CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
"CLIPConfig",
+ "OwlViTConfig",
+ "GroupViTConfig",
"DecisionTransformerConfig",
"EncoderDecoderConfig",
"RagConfig",
diff --git a/utils/check_copies.py b/utils/check_copies.py
index 5363fd1ff338..e2e0e1a53e43 100644
--- a/utils/check_copies.py
+++ b/utils/check_copies.py
@@ -15,6 +15,7 @@
import argparse
import glob
+import importlib.util
import os
import re
@@ -40,26 +41,47 @@
"README.md": {
"start_prompt": "š¤ Transformers currently provides the following architectures",
"end_prompt": "1. Want to contribute a new model?",
- "format_model_list": "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by {paper_authors}.{supplements}",
+ "format_model_list": (
+ "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
+ " {paper_authors}.{supplements}"
+ ),
},
"README_zh-hans.md": {
"start_prompt": "š¤ Transformers ē®åęÆęå¦äøēę¶ę",
"end_prompt": "1. ę³č¦č“”ē®ę°ē樔åļ¼",
- "format_model_list": "**[{title}]({model_link})** (ę„čŖ {paper_affiliations}) 伓é论ę {paper_title_link} ē± {paper_authors} ååøć{supplements}",
+ "format_model_list": (
+ "**[{title}]({model_link})** (ę„čŖ {paper_affiliations}) 伓é论ę {paper_title_link} ē± {paper_authors}"
+ " ååøć{supplements}"
+ ),
},
"README_zh-hant.md": {
"start_prompt": "š¤ Transformers ē®åęÆę“仄äøēę¶ę§",
"end_prompt": "1. ę³č¦č²¢ē»ę°ē樔åļ¼",
- "format_model_list": "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by {paper_authors}.{supplements}",
+ "format_model_list": (
+ "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
+ " {paper_authors}.{supplements}"
+ ),
},
"README_ko.md": {
"start_prompt": "š¤ Transformersė ė¤ģ ėŖØėøė¤ģ ģ ź³µķ©ėė¤",
"end_prompt": "1. ģė”ģ“ ėŖØėøģ ģ¬ė¦¬ź³ ģ¶ėģ?",
- "format_model_list": "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by {paper_authors}.{supplements}",
+ "format_model_list": (
+ "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
+ " {paper_authors}.{supplements}"
+ ),
},
}
+# This is to make sure the transformers module imported is the one in the repo.
+spec = importlib.util.spec_from_file_location(
+ "transformers",
+ os.path.join(TRANSFORMERS_PATH, "__init__.py"),
+ submodule_search_locations=[TRANSFORMERS_PATH],
+)
+transformers_module = spec.loader.load_module()
+
+
def _should_continue(line, indent):
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
@@ -111,6 +133,7 @@ def find_code_in_transformers(object_name):
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
+_re_fill_pattern = re.compile(r"<FILL\s+[^>]*>")
def get_indent(code):
@@ -130,7 +153,7 @@ def blackify(code):
has_indent = len(get_indent(code)) > 0
if has_indent:
code = f"class Bla:\n{code}"
- mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
+ mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119, preview=True)
result = black.format_str(code, mode=mode)
result, _ = style_docstrings_in_code(result)
return result[len("class Bla:\n") :] if has_indent else result
@@ -300,8 +323,6 @@ def _rep(match):
# This regex is used to synchronize link.
_re_capture_title_link = re.compile(r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\*")
- num_models_equal = True
-
if len(localized_model_list) == 0:
localized_model_index = {}
else:
@@ -313,13 +334,24 @@ def _rep(match):
except AttributeError:
raise AttributeError("A model name in localized READMEs cannot be recognized.")
+ model_keys = [re.search(r"\*\*\[([^\]]*)", line).groups()[0] for line in model_list.strip().split("\n")]
+
+ # We exclude keys in localized README not in the main one.
+ readmes_match = not any([k not in model_keys for k in localized_model_index])
+ localized_model_index = {k: v for k, v in localized_model_index.items() if k in model_keys}
+
for model in model_list.strip().split("\n"):
title, model_link = _re_capture_title_link.search(model).groups()
if title not in localized_model_index:
- num_models_equal = False
+ readmes_match = False
# Add an anchor white space behind a model description string for regex.
# If metadata cannot be captured, the English version will be directly copied.
localized_model_index[title] = _re_capture_meta.sub(_rep, model + " ")
+ elif _re_fill_pattern.search(localized_model_index[title]) is not None:
+ update = _re_capture_meta.sub(_rep, model + " ")
+ if update != localized_model_index[title]:
+ readmes_match = False
+ localized_model_index[title] = update
else:
# Synchronize link
localized_model_index[title] = _re_capture_title_link.sub(
@@ -328,7 +360,7 @@ def _rep(match):
sorted_index = sorted(localized_model_index.items(), key=lambda x: x[0].lower())
- return num_models_equal, "\n".join(map(lambda x: x[1], sorted_index)) + "\n"
+ return readmes_match, "\n".join(map(lambda x: x[1], sorted_index)) + "\n"
def convert_readme_to_index(model_list):
@@ -368,7 +400,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
readme = f.read()
new_readme = readme.replace("https://huggingface.co/transformers", "https://huggingface.co/docs/transformers")
- new_readme = readme.replace(
+ new_readme = new_readme.replace(
"https://huggingface.co/docs/main/transformers", "https://huggingface.co/docs/transformers/main"
)
if new_readme != readme:
@@ -400,9 +432,9 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
_format_model_list = value["format_model_list"]
localized_md_list = get_model_list(filename, _start_prompt, _end_prompt)
- num_models_equal, converted_md_list = convert_to_localized_md(md_list, localized_md_list, _format_model_list)
+ readmes_match, converted_md_list = convert_to_localized_md(md_list, localized_md_list, _format_model_list)
- converted_md_lists.append((filename, num_models_equal, converted_md_list, _start_prompt, _end_prompt))
+ converted_md_lists.append((filename, readmes_match, converted_md_list, _start_prompt, _end_prompt))
converted_md_list = convert_readme_to_index(md_list)
if converted_md_list != index_list:
@@ -416,7 +448,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
)
for converted_md_list in converted_md_lists:
- filename, num_models_equal, converted_md, _start_prompt, _end_prompt = converted_md_list
+ filename, readmes_match, converted_md, _start_prompt, _end_prompt = converted_md_list
if filename == "README.md":
continue
@@ -426,17 +458,94 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
)
with open(os.path.join(REPO_PATH, filename), "w", encoding="utf-8", newline="\n") as f:
f.writelines(lines[:start_index] + [converted_md] + lines[end_index:])
- elif not num_models_equal:
+ elif not readmes_match:
raise ValueError(
f"The model list in the README changed and the list in `{filename}` has not been updated. Run "
"`make fix-copies` to fix this."
)
+SPECIAL_MODEL_NAMES = {
+ "Bert Generation": "BERT For Sequence Generation",
+ "BigBird": "BigBird-RoBERTa",
+ "Data2VecAudio": "Data2Vec",
+ "Data2VecText": "Data2Vec",
+ "Data2VecVision": "Data2Vec",
+ "Marian": "MarianMT",
+ "OpenAI GPT-2": "GPT-2",
+ "OpenAI GPT": "GPT",
+ "Perceiver": "Perceiver IO",
+ "ViT": "Vision Transformer (ViT)",
+}
+
+# Update this list with the models that shouldn't be in the README. This only concerns modular models or those that
+# do not have an associated paper.
+MODELS_NOT_IN_README = [
+ "BertJapanese",
+ "Encoder decoder",
+ "FairSeq Machine-Translation",
+ "HerBERT",
+ "RetriBERT",
+ "Speech Encoder decoder",
+ "Speech2Text",
+ "Speech2Text2",
+ "Vision Encoder decoder",
+ "VisionTextDualEncoder",
+]
+
+
+README_TEMPLATE = (
+ "1. **[{model_name}](https://huggingface.co/docs/main/transformers/model_doc/{model_type})** (from "
+ ") released with the paper []() by ."
+)
+
+
+def check_readme(overwrite=False):
+ info = LOCALIZED_READMES["README.md"]
+ models, start_index, end_index, lines = _find_text_in_file(
+ os.path.join(REPO_PATH, "README.md"),
+ info["start_prompt"],
+ info["end_prompt"],
+ )
+ models_in_readme = [re.search(r"\*\*\[([^\]]*)", line).groups()[0] for line in models.strip().split("\n")]
+
+ model_names_mapping = transformers_module.models.auto.configuration_auto.MODEL_NAMES_MAPPING
+ absents = [
+ (key, name)
+ for key, name in model_names_mapping.items()
+ if SPECIAL_MODEL_NAMES.get(name, name) not in models_in_readme
+ ]
+ # Remove exceptions
+ absents = [(key, name) for key, name in absents if name not in MODELS_NOT_IN_README]
+ if len(absents) > 0 and not overwrite:
+ print(absents)
+ raise ValueError(
+ "The main README doesn't contain all models, run `make fix-copies` to fill it with the missing model(s)"
+ " then complete the generated entries.\nIf the model is not supposed to be in the main README, add it to"
+ " the list `MODELS_NOT_IN_README` in utils/check_copies.py.\nIf it has a different name in the repo than"
+ " in the README, map the correspondence in `SPECIAL_MODEL_NAMES` in utils/check_copies.py."
+ )
+
+ new_models = [README_TEMPLATE.format(model_name=name, model_type=key) for key, name in absents]
+
+ all_models = models.strip().split("\n") + new_models
+ all_models = sorted(all_models, key=lambda x: re.search(r"\*\*\[([^\]]*)", x).groups()[0].lower())
+ all_models = "\n".join(all_models) + "\n"
+
+ if all_models != models:
+ if overwrite:
+ print("Fixing the main README.")
+ with open(os.path.join(REPO_PATH, "README.md"), "w", encoding="utf-8", newline="\n") as f:
+ f.writelines(lines[:start_index] + [all_models] + lines[end_index:])
+ else:
+ raise ValueError("The main README model list is not properly sorted. Run `make fix-copies` to fix this.")
+
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
args = parser.parse_args()
+ check_readme(args.fix_and_overwrite)
check_copies(args.fix_and_overwrite)
check_full_copies(args.fix_and_overwrite)
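# A minimal sketch of the importlib pattern added above, which loads the `transformers`
# package from the repository checkout rather than any installed copy. The path below is
# hypothetical; check_copies.py builds the real one from TRANSFORMERS_PATH.
import importlib.util
import os
import sys

package_root = "src/transformers"  # hypothetical checkout location
spec = importlib.util.spec_from_file_location(
    "transformers",
    os.path.join(package_root, "__init__.py"),
    submodule_search_locations=[package_root],
)
module = importlib.util.module_from_spec(spec)
# Register the package before executing it so relative imports inside __init__.py resolve;
# this is the non-deprecated counterpart of the `spec.loader.load_module()` call above.
sys.modules["transformers"] = module
spec.loader.exec_module(module)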
diff --git a/utils/check_doc_toc.py b/utils/check_doc_toc.py
new file mode 100644
index 000000000000..67ec2f94660a
--- /dev/null
+++ b/utils/check_doc_toc.py
@@ -0,0 +1,98 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+from collections import defaultdict
+
+import yaml
+
+
+PATH_TO_TOC = "docs/source/en/_toctree.yml"
+
+
+def clean_model_doc_toc(model_doc):
+ """
+ Cleans the table of contents of the model documentation by removing duplicates and sorting models alphabetically.
+ """
+ counts = defaultdict(int)
+ for doc in model_doc:
+ counts[doc["local"]] += 1
+ duplicates = [key for key, value in counts.items() if value > 1]
+
+ new_doc = []
+ for duplicate_key in duplicates:
+ titles = list(set(doc["title"] for doc in model_doc if doc["local"] == duplicate_key))
+ if len(titles) > 1:
+ raise ValueError(
+ f"{duplicate_key} is present several times in the documentation table of content at "
+ "`docs/source/en/_toctree.yml` with different *Title* values. Choose one of those and remove the "
+ "others."
+ )
+ # Only add this once
+ new_doc.append({"local": duplicate_key, "title": titles[0]})
+
+ # Add non-duplicate keys
+ new_doc.extend([doc for doc in model_doc if counts[doc["local"]] == 1])
+
+ # Sort
+ return sorted(new_doc, key=lambda s: s["title"].lower())
+
+
+def check_model_doc(overwrite=False):
+ with open(PATH_TO_TOC, encoding="utf-8") as f:
+ content = yaml.safe_load(f.read())
+
+ # Get to the API doc
+ api_idx = 0
+ while content[api_idx]["title"] != "API":
+ api_idx += 1
+ api_doc = content[api_idx]["sections"]
+
+ # Then to the model doc
+ model_idx = 0
+ while api_doc[model_idx]["title"] != "Models":
+ model_idx += 1
+
+ model_doc = api_doc[model_idx]["sections"]
+
+ modalities_docs = [(idx, section) for idx, section in enumerate(model_doc) if "sections" in section]
+ diff = False
+ for idx, modality_doc in modalities_docs:
+ old_modality_doc = modality_doc["sections"]
+ new_modality_doc = clean_model_doc_toc(old_modality_doc)
+
+ if old_modality_doc != new_modality_doc:
+ diff = True
+ if overwrite:
+ model_doc[idx]["sections"] = new_modality_doc
+
+ if diff:
+ if overwrite:
+ api_doc[model_idx]["sections"] = model_doc
+ content[api_idx]["sections"] = api_doc
+ with open(PATH_TO_TOC, "w", encoding="utf-8") as f:
+ f.write(yaml.dump(content, allow_unicode=True))
+ else:
+ raise ValueError(
+ "The model doc part of the table of content is not properly sorted, run `make style` to fix this."
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
+ args = parser.parse_args()
+
+ check_model_doc(args.fix_and_overwrite)
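# An illustrative input/output pair for `clean_model_doc_toc` above: duplicate `local`
# entries sharing a title are collapsed to one, and the result is sorted by title.
example_toc = [
    {"local": "model_doc/bert", "title": "BERT"},
    {"local": "model_doc/albert", "title": "ALBERT"},
    {"local": "model_doc/bert", "title": "BERT"},
]
# clean_model_doc_toc(example_toc) ->
# [{"local": "model_doc/albert", "title": "ALBERT"}, {"local": "model_doc/bert", "title": "BERT"}]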
diff --git a/utils/check_dummies.py b/utils/check_dummies.py
index c1625036c4e3..484aac25452f 100644
--- a/utils/check_dummies.py
+++ b/utils/check_dummies.py
@@ -26,7 +26,7 @@
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
# Matches from xxx import bla
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
-_re_test_backend = re.compile(r"^\s+if\s+is\_[a-z]*\_available\(\)")
+_re_test_backend = re.compile(r"^\s+if\s+not\s+is\_[a-z_]*\_available\(\)")
DUMMY_CONSTANT = """
@@ -73,6 +73,8 @@ def read_init():
# If the line is an if is_backend_available, we grab all objects associated.
backend = find_backend(lines[line_index])
if backend is not None:
+ while not lines[line_index].startswith(" else:"):
+ line_index += 1
line_index += 1
objects = []
diff --git a/utils/check_inits.py b/utils/check_inits.py
index 18353581fcff..98d4caf01021 100644
--- a/utils/check_inits.py
+++ b/utils/check_inits.py
@@ -25,10 +25,12 @@
# Matches is_xxx_available()
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
+# Catches a one-line _import_struct = {xxx}
+_re_one_line_import_struct = re.compile(r"^_import_structure\s+=\s+\{([^\}]+)\}")
# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
-# Catches a line if is_foo_available
-_re_test_backend = re.compile(r"^\s*if\s+is\_[a-z_]*\_available\(\)")
+# Catches a line if not is_foo_available
+_re_test_backend = re.compile(r"^\s*if\s+not\s+is\_[a-z_]*\_available\(\)")
# Catches a line _import_struct["bla"].append("foo")
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
@@ -39,6 +41,10 @@
_re_between_brackets = re.compile("^\s+\[([^\]]+)\]")
# Catches a line with from foo import bar, bla, boo
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
+# Catches a line with try:
+_re_try = re.compile(r"^\s*try:")
+# Catches a line with else:
+_re_else = re.compile(r"^\s*else:")
def find_backend(line):
@@ -70,6 +76,14 @@ def parse_init(init_file):
objects = []
while not lines[line_index].startswith("if TYPE_CHECKING") and find_backend(lines[line_index]) is None:
line = lines[line_index]
+ # If we have everything on a single line, let's deal with it.
+ if _re_one_line_import_struct.search(line):
+ content = _re_one_line_import_struct.search(line).groups()[0]
+ imports = re.findall("\[([^\]]+)\]", content)
+ for imp in imports:
+ objects.extend([obj[1:-1] for obj in imp.split(", ")])
+ line_index += 1
+ continue
single_line_import_search = _re_import_struct_key_value.search(line)
if single_line_import_search is not None:
imports = [obj[1:-1] for obj in single_line_import_search.groups()[0].split(", ") if len(obj) > 0]
@@ -81,11 +95,21 @@ def parse_init(init_file):
import_dict_objects = {"none": objects}
# Let's continue with backend-specific objects in _import_structure
while not lines[line_index].startswith("if TYPE_CHECKING"):
- # If the line is an if is_backend_available, we grab all objects associated.
+ # If the line is an if not is_backend_available, we grab all objects associated.
backend = find_backend(lines[line_index])
+ # Check if the backend declaration is inside a try block:
+ if _re_try.search(lines[line_index - 1]) is None:
+ backend = None
+
if backend is not None:
line_index += 1
+ # Scroll until we hit the else block of try-except-else
+ while _re_else.search(lines[line_index]) is None:
+ line_index += 1
+
+ line_index += 1
+
objects = []
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
@@ -130,11 +154,21 @@ def parse_init(init_file):
type_hint_objects = {"none": objects}
# Let's continue with backend-specific objects
while line_index < len(lines):
- # If the line is an if is_backemd_available, we grab all objects associated.
+ # If the line is an if not is_backend_available, we grab all objects associated.
backend = find_backend(lines[line_index])
+ # Check if the backend declaration is inside a try block:
+ if _re_try.search(lines[line_index - 1]) is None:
+ backend = None
+
if backend is not None:
line_index += 1
+ # Scroll until we hit the else block of try-except-else
+ while _re_else.search(lines[line_index]) is None:
+ line_index += 1
+
+ line_index += 1
+
objects = []
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
@@ -225,7 +259,7 @@ def get_transformers_submodules():
if fname == "__init__.py":
continue
short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS))
- submodule = short_path.replace(os.path.sep, ".").replace(".py", "")
+ submodule = short_path.replace(".py", "").replace(os.path.sep, ".")
if len(submodule.split(".")) == 1:
submodules.append(submodule)
return submodules
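# A minimal sketch of the try/except/else `__init__.py` layout the updated parsing above is
# written for: backend-specific entries in `_import_structure` now live under the `else:`
# branch of a try block. `FooConfig`/`FooModel` and the bare ImportError are placeholders;
# the real inits use a dedicated exception class and register dummy objects on failure.
from transformers.utils import is_torch_available

_import_structure = {"configuration_foo": ["FooConfig"]}

try:
    if not is_torch_available():
        raise ImportError("torch is required for the modeling objects")
except ImportError:
    pass  # the real files register dummy objects here instead
else:
    _import_structure["modeling_foo"] = ["FooModel", "FooPreTrainedModel"]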
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 5afaac02a123..d2271e87ebf1 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -36,6 +36,7 @@
# Update this list with models that are supposed to be private.
PRIVATE_MODELS = [
"DPRSpanPredictor",
+ "LongT5Stack",
"RealmBertModel",
"T5Stack",
"TFDPRSpanPredictor",
@@ -45,6 +46,7 @@
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
+ "OPTDecoder", # Building part of bigger (tested) model.
"DecisionTransformerGPT2Model", # Building part of bigger (tested) model.
"SegformerDecodeHead", # Building part of bigger (tested) model.
"PLBartEncoder", # Building part of bigger (tested) model.
@@ -58,6 +60,7 @@
"DetrDecoderWrapper", # Building part of bigger (tested) model.
"M2M100Encoder", # Building part of bigger (tested) model.
"M2M100Decoder", # Building part of bigger (tested) model.
+ "MCTCTEncoder", # Building part of bigger (tested) model.
"Speech2TextEncoder", # Building part of bigger (tested) model.
"Speech2TextDecoder", # Building part of bigger (tested) model.
"LEDEncoder", # Building part of bigger (tested) model.
@@ -75,6 +78,8 @@
"MegatronBertEncoder", # Building part of bigger (tested) model.
"MegatronBertDecoder", # Building part of bigger (tested) model.
"MegatronBertDecoderWrapper", # Building part of bigger (tested) model.
+ "MvpDecoderWrapper", # Building part of bigger (tested) model.
+ "MvpEncoder", # Building part of bigger (tested) model.
"PegasusEncoder", # Building part of bigger (tested) model.
"PegasusDecoderWrapper", # Building part of bigger (tested) model.
"DPREncoder", # Building part of bigger (tested) model.
@@ -92,6 +97,8 @@
"SeparableConv1D", # Building part of bigger (tested) model.
"FlaxBartForCausalLM", # Building part of bigger (tested) model.
"FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
+ "OPTDecoderWrapper",
+ "TFSegformerDecodeHead", # Not a regular model.
]
# Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't
@@ -124,6 +131,7 @@
"ViltForQuestionAnswering",
"ViltForImagesAndTextClassification",
"ViltForImageAndTextRetrieval",
+ "ViltForTokenClassification",
"ViltForMaskedLM",
"XGLMEncoder",
"XGLMDecoder",
@@ -131,6 +139,7 @@
"PerceiverForMultimodalAutoencoding",
"PerceiverForOpticalFlow",
"SegformerDecodeHead",
+ "TFSegformerDecodeHead",
"FlaxBeitForMaskedImageModeling",
"PLBartEncoder",
"PLBartDecoder",
@@ -138,6 +147,8 @@
"BeitForMaskedImageModeling",
"CLIPTextModel",
"CLIPVisionModel",
+ "GroupViTTextModel",
+ "GroupViTVisionModel",
"TFCLIPTextModel",
"TFCLIPVisionModel",
"FlaxCLIPTextModel",
@@ -146,12 +157,19 @@
"DetrForSegmentation",
"DPRReader",
"FlaubertForQuestionAnswering",
+ "FlavaImageCodebook",
+ "FlavaTextModel",
+ "FlavaImageModel",
+ "FlavaMultimodalModel",
"GPT2DoubleHeadsModel",
"LukeForMaskedLM",
"LukeForEntityClassification",
"LukeForEntityPairClassification",
"LukeForEntitySpanClassification",
"OpenAIGPTDoubleHeadsModel",
+ "OwlViTTextModel",
+ "OwlViTVisionModel",
+ "OwlViTForObjectDetection",
"RagModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
@@ -518,7 +536,8 @@ def check_all_decorator_order():
if len(errors) > 0:
msg = "\n".join(errors)
raise ValueError(
- f"The parameterized decorator (and its variants) should always be first, but this is not the case in the following files:\n{msg}"
+ "The parameterized decorator (and its variants) should always be first, but this is not the case in the"
+ f" following files:\n{msg}"
)
@@ -595,7 +614,6 @@ def find_all_documented_objects():
"absl", # External module
"add_end_docstrings", # Internal, should never have been in the main init.
"add_start_docstrings", # Internal, should never have been in the main init.
- "cached_path", # Internal used for downloading models.
"convert_tf_weight_name_to_pt_weight_name", # Internal used to convert model weights
"logger", # Internal logger
"logging", # External module
@@ -717,7 +735,7 @@ def check_docstrings_are_in_md():
"""Check all docstrings are in md"""
files_with_rst = []
for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
- with open(file, "r") as f:
+ with open(file, encoding="utf-8") as f:
code = f.read()
docstrings = code.split('"""')
diff --git a/utils/check_table.py b/utils/check_table.py
index d59f3e7b1e5a..96d0cf23d26e 100644
--- a/utils/check_table.py
+++ b/utils/check_table.py
@@ -53,7 +53,7 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
return "".join(lines[start_index:end_index]), start_index, end_index, lines
-# Add here suffixes that are used to identify models, seperated by |
+# Add here suffixes that are used to identify models, separated by |
ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration"
# Regexes that match TF/Flax/PT model names.
_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
diff --git a/utils/custom_init_isort.py b/utils/custom_init_isort.py
index 456ff4aedc94..375cdb662f3a 100644
--- a/utils/custom_init_isort.py
+++ b/utils/custom_init_isort.py
@@ -167,7 +167,7 @@ def sort_imports(file, check_only=True):
"""
Sort `_import_structure` imports in `file`, `check_only` determines if we only check or overwrite.
"""
- with open(file, "r") as f:
+ with open(file, encoding="utf-8") as f:
code = f.read()
if "_import_structure" not in code:
@@ -227,7 +227,7 @@ def sort_imports(file, check_only=True):
return True
else:
print(f"Overwriting {file}.")
- with open(file, "w") as f:
+ with open(file, "w", encoding="utf-8") as f:
f.write("\n".join(main_blocks))
diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt
index 8d8c7eccf2cd..1941a7343a6b 100644
--- a/utils/documentation_tests.txt
+++ b/utils/documentation_tests.txt
@@ -8,6 +8,7 @@ docs/source/en/model_doc/t5.mdx
docs/source/en/model_doc/t5v1.1.mdx
docs/source/en/model_doc/byt5.mdx
docs/source/en/model_doc/tapex.mdx
+docs/source/en/model_doc/encoder-decoder.mdx
src/transformers/generation_utils.py
src/transformers/models/albert/modeling_albert.py
src/transformers/models/albert/modeling_tf_albert.py
@@ -21,9 +22,12 @@ src/transformers/models/blenderbot/modeling_blenderbot.py
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
src/transformers/models/convnext/modeling_convnext.py
src/transformers/models/ctrl/modeling_ctrl.py
+src/transformers/models/cvt/modeling_cvt.py
src/transformers/models/data2vec/modeling_data2vec_audio.py
src/transformers/models/data2vec/modeling_data2vec_vision.py
src/transformers/models/deit/modeling_deit.py
+src/transformers/models/deit/modeling_tf_deit.py
+src/transformers/models/detr/modeling_detr.py
src/transformers/models/dpt/modeling_dpt.py
src/transformers/models/electra/modeling_electra.py
src/transformers/models/electra/modeling_tf_electra.py
@@ -31,15 +35,27 @@ src/transformers/models/glpn/modeling_glpn.py
src/transformers/models/gpt2/modeling_gpt2.py
src/transformers/models/gptj/modeling_gptj.py
src/transformers/models/hubert/modeling_hubert.py
+src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
+src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
+src/transformers/models/longformer/modeling_longformer.py
+src/transformers/models/longformer/modeling_tf_longformer.py
+src/transformers/models/longt5/modeling_longt5.py
src/transformers/models/marian/modeling_marian.py
src/transformers/models/mbart/modeling_mbart.py
src/transformers/models/mobilebert/modeling_mobilebert.py
src/transformers/models/mobilebert/modeling_tf_mobilebert.py
+src/transformers/models/mobilevit/modeling_mobilevit.py
+src/transformers/models/opt/modeling_opt.py
+src/transformers/models/opt/modeling_tf_opt.py
+src/transformers/models/owlvit/modeling_owlvit.py
src/transformers/models/pegasus/modeling_pegasus.py
src/transformers/models/plbart/modeling_plbart.py
src/transformers/models/poolformer/modeling_poolformer.py
src/transformers/models/reformer/modeling_reformer.py
+src/transformers/models/regnet/modeling_regnet.py
+src/transformers/models/regnet/modeling_tf_regnet.py
src/transformers/models/resnet/modeling_resnet.py
+src/transformers/models/resnet/modeling_tf_resnet.py
src/transformers/models/roberta/modeling_roberta.py
src/transformers/models/roberta/modeling_tf_roberta.py
src/transformers/models/segformer/modeling_segformer.py
@@ -48,11 +64,13 @@ src/transformers/models/sew_d/modeling_sew_d.py
src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
src/transformers/models/speech_to_text/modeling_speech_to_text.py
src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
+src/transformers/models/segformer/modeling_tf_segformer.py
src/transformers/models/swin/modeling_swin.py
src/transformers/models/trocr/modeling_trocr.py
src/transformers/models/unispeech/modeling_unispeech.py
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
src/transformers/models/van/modeling_van.py
+src/transformers/models/videomae/modeling_videomae.py
src/transformers/models/vilt/modeling_vilt.py
src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
src/transformers/models/vit/modeling_vit.py
@@ -60,6 +78,7 @@ src/transformers/models/vit/modeling_tf_vit.py
src/transformers/models/vit_mae/modeling_vit_mae.py
src/transformers/models/wav2vec2/modeling_wav2vec2.py
src/transformers/models/wav2vec2/tokenization_wav2vec2.py
+src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
-src/transformers/models/wavlm/modeling_wavlm.py
+src/transformers/models/wavlm/modeling_wavlm.py
src/transformers/models/yolos/modeling_yolos.py
diff --git a/utils/notification_service.py b/utils/notification_service.py
index 47e85d867e5e..4918b4a459ac 100644
--- a/utils/notification_service.py
+++ b/utils/notification_service.py
@@ -98,8 +98,9 @@ def dicts_to_sum(objects: Union[Dict[str, Dict], List[dict]]):
class Message:
- def __init__(self, title: str, model_results: Dict, additional_results: Dict):
+ def __init__(self, title: str, ci_title: str, model_results: Dict, additional_results: Dict):
self.title = title
+ self.ci_title = ci_title
# Failures and success of the modeling tests
self.n_model_success = sum(r["success"] for r in model_results.values())
@@ -158,6 +159,10 @@ def time(self) -> str:
def header(self) -> Dict:
return {"type": "header", "text": {"type": "plain_text", "text": self.title}}
+ @property
+ def ci_title_section(self) -> Dict:
+ return {"type": "section", "text": {"type": "mrkdwn", "text": self.ci_title}}
+
@property
def no_failures(self) -> Dict:
return {
@@ -180,7 +185,10 @@ def failures(self) -> Dict:
"type": "section",
"text": {
"type": "plain_text",
- "text": f"There were {self.n_failures} failures, out of {self.n_tests} tests.\nThe suite ran in {self.time}.",
+ "text": (
+ f"There were {self.n_failures} failures, out of {self.n_tests} tests.\nThe suite ran in"
+ f" {self.time}."
+ ),
"emoji": True,
},
"accessory": {
@@ -225,15 +233,11 @@ def category_failures(self) -> Dict:
individual_reports.append(key)
header = "Single | Multi | Category\n"
- category_failures_report = header + "\n".join(sorted(individual_reports))
+ category_failures_report = prepare_reports(
+ title="The following modeling categories had failures", header=header, reports=individual_reports
+ )
- return {
- "type": "section",
- "text": {
- "type": "mrkdwn",
- "text": f"The following modeling categories had failures:\n\n```\n{category_failures_report}\n```",
- },
- }
+ return {"type": "section", "text": {"type": "mrkdwn", "text": category_failures_report}}
@property
def model_failures(self) -> Dict:
@@ -294,21 +298,44 @@ def per_model_sum(model_category_dict):
model_header = "Single PT | Multi PT | Single TF | Multi TF | Other | Category\n"
sorted_model_reports = sorted(model_reports, key=lambda s: s.split("] ")[-1])
- model_failures_report = model_header + "\n".join(sorted_model_reports)
+ model_failures_report = prepare_reports(
+ title="These following model modules had failures", header=model_header, reports=sorted_model_reports
+ )
module_header = "Single | Multi | Category\n"
sorted_module_reports = sorted(other_module_reports, key=lambda s: s.split("] ")[-1])
- module_failures_report = module_header + "\n".join(sorted_module_reports)
-
- report = ""
-
- if len(model_reports):
- report += f"These following model modules had failures:\n```\n{model_failures_report}\n```\n\n"
+ module_failures_report = prepare_reports(
+ title="The following non-model modules had failures", header=module_header, reports=sorted_module_reports
+ )
- if len(other_module_reports):
- report += f"The following non-model modules had failures:\n```\n{module_failures_report}\n```\n\n"
+ model_failure_sections = [
+ {"type": "section", "text": {"type": "mrkdwn", "text": model_failures_report}},
+ {"type": "section", "text": {"type": "mrkdwn", "text": module_failures_report}},
+ ]
- return {"type": "section", "text": {"type": "mrkdwn", "text": report}}
+ # Save complete tables (for past CI) - to be uploaded as artifacts
+ if ci_event.startswith("Past CI"):
+ model_failures_report = prepare_reports(
+ title="These following model modules had failures",
+ header=model_header,
+ reports=sorted_model_reports,
+ to_truncate=False,
+ )
+ file_path = os.path.join(os.getcwd(), "test_failure_tables/model_failures_report.txt")
+ with open(file_path, "w", encoding="UTF-8") as fp:
+ fp.write(model_failures_report)
+
+ module_failures_report = prepare_reports(
+ title="The following non-model modules had failures",
+ header=module_header,
+ reports=sorted_module_reports,
+ to_truncate=False,
+ )
+ file_path = os.path.join(os.getcwd(), "test_failure_tables/module_failures_report.txt")
+ with open(file_path, "w", encoding="UTF-8") as fp:
+ fp.write(module_failures_report)
+
+ return model_failure_sections
@property
def additional_failures(self) -> Dict:
@@ -329,25 +356,27 @@ def additional_failures(self) -> Dict:
individual_reports.append(report)
header = "Single | Multi | Category\n"
- failures_report = header + "\n".join(sorted(individual_reports))
+ failures_report = prepare_reports(
+ title="The following non-modeling tests had failures", header=header, reports=individual_reports
+ )
- return {
- "type": "section",
- "text": {
- "type": "mrkdwn",
- "text": f"The following non-modeling tests had failures:\n```\n{failures_report}\n```",
- },
- }
+ return {"type": "section", "text": {"type": "mrkdwn", "text": failures_report}}
@property
def payload(self) -> str:
blocks = [self.header]
+ if self.ci_title:
+ blocks.append(self.ci_title_section)
+
if self.n_model_failures > 0 or self.n_additional_failures > 0:
blocks.append(self.failures)
if self.n_model_failures > 0:
- blocks.extend([self.category_failures, self.model_failures])
+ blocks.append(self.category_failures)
+ for block in self.model_failures:
+ if block["text"]["text"]:
+ blocks.append(block)
if self.n_additional_failures > 0:
blocks.append(self.additional_failures)
@@ -378,7 +407,7 @@ def error_out():
print(json.dumps({"blocks": json.loads(payload)}))
client.chat_postMessage(
- channel=os.environ["CI_SLACK_CHANNEL_ID_DAILY"],
+ channel=os.environ["CI_SLACK_REPORT_CHANNEL_ID"],
text="There was an issue running the tests.",
blocks=payload,
)
@@ -390,14 +419,28 @@ def post(self):
text = f"{self.n_failures} failures out of {self.n_tests} tests," if self.n_failures else "All tests passed."
self.thread_ts = client.chat_postMessage(
- channel=os.environ["CI_SLACK_CHANNEL_ID_DAILY"],
+ channel=os.environ["CI_SLACK_REPORT_CHANNEL_ID"],
blocks=self.payload,
text=text,
)
def get_reply_blocks(self, job_name, job_result, failures, device, text):
- if len(failures) > 2500:
- failures = "\n".join(failures.split("\n")[:20]) + "\n\n[Truncated]"
+ """
+ failures: A list with elements of the form {"line": full test name, "trace": error trace}
+ """
+ # `text` must be less than 3001 characters in Slack SDK
+ # keep some room for adding "[Truncated]" when necessary
+ MAX_ERROR_TEXT = 3000 - len("[Truncated]")
+
+ failure_text = ""
+ for idx, error in enumerate(failures):
+ new_text = failure_text + f'*{error["line"]}*\n_{error["trace"]}_\n\n'
+ if len(new_text) > MAX_ERROR_TEXT:
+ # `failure_text` here has length <= 3000
+ failure_text = failure_text + "[Truncated]"
+ break
+ # `failure_text` here has length <= MAX_ERROR_TEXT
+ failure_text = new_text
title = job_name
if device is not None:
@@ -405,17 +448,22 @@ def get_reply_blocks(self, job_name, job_result, failures, device, text):
content = {"type": "section", "text": {"type": "mrkdwn", "text": text}}
- if job_result["job_link"] is not None:
+ # TODO: Make sure we always have a valid job link (or at least a way not to break the report sending)
+ # Currently we get the device from a job's artifact name.
+ # If a device is found, the job name should contain the device type, for example, `XXX (single-gpu)`.
+ # This could be done by adding `machine_type` in a job's `strategy`.
+ # (If `job_result["job_link"][device]` is `None`, we get an error: `... [ERROR] must provide a string ...`)
+ if job_result["job_link"] is not None and job_result["job_link"][device] is not None:
content["accessory"] = {
"type": "button",
"text": {"type": "plain_text", "text": "GitHub Action job", "emoji": True},
- "url": job_result["job_link"],
+ "url": job_result["job_link"][device],
}
return [
{"type": "header", "text": {"type": "plain_text", "text": title.upper(), "emoji": True}},
content,
- {"type": "section", "text": {"type": "mrkdwn", "text": failures}},
+ {"type": "section", "text": {"type": "mrkdwn", "text": failure_text}},
]
def post_reply(self):
@@ -436,7 +484,7 @@ def post_reply(self):
print(json.dumps({"blocks": blocks}))
client.chat_postMessage(
- channel=os.environ["CI_SLACK_CHANNEL_ID_DAILY"],
+ channel=os.environ["CI_SLACK_REPORT_CHANNEL_ID"],
text=f"Results for {job}",
blocks=blocks,
thread_ts=self.thread_ts["ts"],
@@ -459,7 +507,7 @@ def post_reply(self):
print(json.dumps({"blocks": blocks}))
client.chat_postMessage(
- channel=os.environ["CI_SLACK_CHANNEL_ID_DAILY"],
+ channel=os.environ["CI_SLACK_REPORT_CHANNEL_ID"],
text=f"Results for {job}",
blocks=blocks,
thread_ts=self.thread_ts["ts"],
@@ -494,7 +542,7 @@ def retrieve_artifact(name: str, gpu: Optional[str]):
raise ValueError(f"Invalid GPU for artifact. Passed GPU: `{gpu}`.")
if gpu is not None:
- name = f"{gpu}-gpu-docker_{name}"
+ name = f"{gpu}-gpu_{name}"
_artifact = {}
@@ -528,8 +576,8 @@ def add_path(self, path: str, gpu: str = None):
directories = filter(os.path.isdir, os.listdir())
for directory in directories:
- if directory.startswith("single-gpu-docker"):
- artifact_name = directory[len("single-gpu-docker") + 1 :]
+ if directory.startswith("single-gpu"):
+ artifact_name = directory[len("single-gpu") + 1 :]
if artifact_name in _available_artifacts:
_available_artifacts[artifact_name].single_gpu = True
@@ -538,8 +586,8 @@ def add_path(self, path: str, gpu: str = None):
_available_artifacts[artifact_name].add_path(directory, gpu="single")
- elif directory.startswith("multi-gpu-docker"):
- artifact_name = directory[len("multi-gpu-docker") + 1 :]
+ elif directory.startswith("multi-gpu"):
+ artifact_name = directory[len("multi-gpu") + 1 :]
if artifact_name in _available_artifacts:
_available_artifacts[artifact_name].multi_gpu = True
@@ -557,7 +605,90 @@ def add_path(self, path: str, gpu: str = None):
return _available_artifacts
+def prepare_reports(title, header, reports, to_truncate=True):
+ report = ""
+
+ MAX_ERROR_TEXT = 3000 - len("[Truncated]")
+ if not to_truncate:
+ MAX_ERROR_TEXT = float("inf")
+
+ if len(reports) > 0:
+ # `text` must be less than 3001 characters in Slack SDK
+ # keep some room for adding "[Truncated]" when necessary
+
+ for idx in range(len(reports)):
+ _report = header + "\n".join(reports[: idx + 1])
+ new_report = f"{title}:\n```\n{_report}\n```\n"
+ if len(new_report) > MAX_ERROR_TEXT:
+ # `report` here has length <= 3000
+ report = report + "[Truncated]"
+ break
+ report = new_report
+
+ return report
+
+
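# An illustrative call to `prepare_reports` above: rows are joined under the header inside a
# fenced block for Slack, and once the text would exceed the ~3000-character Slack limit the
# last complete table is kept and "[Truncated]" is appended. The rows below are made up.
example_rows = [" 3 |  1 | models_bert", " 0 |  2 | models_gpt2"]
example_report = prepare_reports(
    title="The following model modules had failures",
    header="Single | Multi | Category\n",
    reports=example_rows,
)
# example_report ==
# "The following model modules had failures:\n```\nSingle | Multi | Category\n 3 |  1 | models_bert\n 0 |  2 | models_gpt2\n```\n"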
if __name__ == "__main__":
+
+ org = "huggingface"
+ repo = "transformers"
+ repository_full_name = f"{org}/{repo}"
+
+ # This env. variable is set in workflow file (under the job `send_results`).
+ ci_event = os.environ["CI_EVENT"]
+
+ # To find the PR number in a commit title, for example, `Add AwesomeFormer model (#99999)`
+ pr_number_re = re.compile(r"\(#(\d+)\)$")
+
+ title = f"š¤ Results of the {ci_event} tests."
+ # Add Commit/PR title with a link for push CI
+ # (check the title in 2 env. variables - depending on the CI is triggered via `push` or `workflow_run` event)
+ ci_title_push = os.environ.get("CI_TITLE_PUSH")
+ ci_title_workflow_run = os.environ.get("CI_TITLE_WORKFLOW_RUN")
+ ci_title = ci_title_push if ci_title_push else ci_title_workflow_run
+
+ ci_sha = os.environ.get("CI_SHA")
+
+ ci_url = None
+ if ci_sha:
+ ci_url = f"https://github.com/{repository_full_name}/commit/{ci_sha}"
+
+ if ci_title is not None:
+ if ci_url is None:
+ raise ValueError(
+ "When a title is found (`ci_title`), it means a `push` event or a `workflow_run` even (triggered by "
+ "another `push` event), and the commit SHA has to be provided in order to create the URL to the "
+ "commit page."
+ )
+ ci_title = ci_title.strip().split("\n")[0].strip()
+
+ # Retrieve the PR title and author login to complete the report
+ commit_number = ci_url.split("/")[-1]
+ ci_detail_url = f"https://api.github.com/repos/{repository_full_name}/commits/{commit_number}"
+ ci_details = requests.get(ci_detail_url).json()
+ ci_author = ci_details["author"]["login"]
+
+ merged_by = None
+ # Find the PR number (if any) and change the url to the actual PR page.
+ numbers = pr_number_re.findall(ci_title)
+ if len(numbers) > 0:
+ pr_number = numbers[0]
+ ci_detail_url = f"https://api.github.com/repos/{repository_full_name}/pulls/{pr_number}"
+ ci_details = requests.get(ci_detail_url).json()
+
+ ci_author = ci_details["user"]["login"]
+ ci_url = f"https://github.com/{repository_full_name}/pull/{pr_number}"
+
+ merged_by = ci_details["merged_by"]["login"]
+
+ if merged_by is None:
+ ci_title = f"<{ci_url}|{ci_title}>\nAuthor: {ci_author}"
+ else:
+ ci_title = f"<{ci_url}|{ci_title}>\nAuthor: {ci_author} | Merged by: {merged_by}"
+
+ else:
+ ci_title = ""
+
arguments = sys.argv[1:][0]
try:
models = ast.literal_eval(arguments)
@@ -593,6 +724,7 @@ def add_path(self, path: str, gpu: str = None):
"success": 0,
"time_spent": "",
"failures": {},
+ "job_link": {},
}
for model in models
if f"run_all_tests_gpu_{model}_test_reports" in available_artifacts
@@ -600,15 +732,24 @@ def add_path(self, path: str, gpu: str = None):
unclassified_model_failures = []
+ # This prefix is used to get job links below. For past CI, we use `workflow_call`, which changes the job names from
+ # `Model tests (...)` to `PyTorch 1.5 / Model tests (...)` for example.
+ job_name_prefix = ""
+ if ci_event.startswith("Past CI - "):
+ framework, version = ci_event.replace("Past CI - ", "").split("-")
+ framework = "PyTorch" if framework == "pytorch" else "TensorFlow"
+ job_name_prefix = f"{framework} {version}"
+
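To make the prefix derivation above concrete (the event name below is hypothetical, chosen only to match the `Past CI - <framework>-<version>` pattern the comment describes):

```python
# Sketch of the prefix derivation for a hypothetical past-CI event name.
ci_event = "Past CI - pytorch-1.11"
framework, version = ci_event.replace("Past CI - ", "").split("-")
framework = "PyTorch" if framework == "pytorch" else "TensorFlow"
assert f"{framework} {version}" == "PyTorch 1.11"
```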
for model in model_results.keys():
for artifact_path in available_artifacts[f"run_all_tests_gpu_{model}_test_reports"].paths:
artifact = retrieve_artifact(artifact_path["name"], artifact_path["gpu"])
if "stats" in artifact:
# Link to the GitHub Action job
- model_results[model]["job_link"] = github_actions_job_links.get(
- f"Model tests ({model}, {artifact_path['gpu']}-gpu-docker)"
- )
-
+ # The job names use `matrix.folder`, which contains values like `models/bert` instead of `models_bert`
+ job_name = f"Model tests ({model.replace('models_', 'models/')}, {artifact_path['gpu']}-gpu)"
+ if job_name_prefix:
+ job_name = f"{job_name_prefix} / {job_name}"
+ model_results[model]["job_link"][artifact_path["gpu"]] = github_actions_job_links.get(job_name)
failed, success, time_spent = handle_test_results(artifact["stats"])
model_results[model]["success"] += success
model_results[model]["time_spent"] += time_spent[1:-1] + ", "
@@ -622,16 +763,16 @@ def add_path(self, path: str, gpu: str = None):
line = line.split()[0].replace("\n", "")
if artifact_path["gpu"] not in model_results[model]["failures"]:
- model_results[model]["failures"][artifact_path["gpu"]] = ""
+ model_results[model]["failures"][artifact_path["gpu"]] = []
- model_results[model]["failures"][
- artifact_path["gpu"]
- ] += f"*{line}*\n_{stacktraces.pop(0)}_\n\n"
+ model_results[model]["failures"][artifact_path["gpu"]].append(
+ {"line": line, "trace": stacktraces.pop(0)}
+ )
- if re.search("_tf_", line):
+ if re.search("test_modeling_tf_", line):
model_results[model]["failed"]["TensorFlow"][artifact_path["gpu"]] += 1
- elif re.search("_flax_", line):
+ elif re.search("test_modeling_flax_", line):
model_results[model]["failed"]["Flax"][artifact_path["gpu"]] += 1
elif re.search("test_modeling", line):
@@ -664,6 +805,11 @@ def add_path(self, path: str, gpu: str = None):
"Torch CUDA extension tests": "run_tests_torch_cuda_extensions_gpu_test_reports",
}
+ if ci_event == "push":
+ del additional_files["Examples directory"]
+ del additional_files["PyTorch pipelines"]
+ del additional_files["TensorFlow pipelines"]
+
additional_results = {
key: {
"failed": {"unclassified": 0, "single": 0, "multi": 0},
@@ -671,7 +817,7 @@ def add_path(self, path: str, gpu: str = None):
"time_spent": "",
"error": False,
"failures": {},
- "job_link": github_actions_job_links.get(key),
+ "job_link": {},
}
for key in additional_files.keys()
}
@@ -685,9 +831,12 @@ def add_path(self, path: str, gpu: str = None):
for artifact_path in available_artifacts[additional_files[key]].paths:
if artifact_path["gpu"] is not None:
- additional_results[key]["job_link"] = github_actions_job_links.get(
- f"{key} ({artifact_path['gpu']}-gpu-docker)"
+ additional_results[key]["job_link"][artifact_path["gpu"]] = github_actions_job_links.get(
+ f"{key} ({artifact_path['gpu']}-gpu)"
)
+ else:
+ additional_results[key]["job_link"][artifact_path["gpu"]] = github_actions_job_links.get(key)
+
artifact = retrieve_artifact(artifact_path["name"], artifact_path["gpu"])
stacktraces = handle_stacktraces(artifact["failures_line"])
@@ -706,13 +855,15 @@ def add_path(self, path: str, gpu: str = None):
line = line.split()[0].replace("\n", "")
if artifact_path["gpu"] not in additional_results[key]["failures"]:
- additional_results[key]["failures"][artifact_path["gpu"]] = ""
+ additional_results[key]["failures"][artifact_path["gpu"]] = []
- additional_results[key]["failures"][
- artifact_path["gpu"]
- ] += f"*{line}*\n_{stacktraces.pop(0)}_\n\n"
+ additional_results[key]["failures"][artifact_path["gpu"]].append(
+ {"line": line, "trace": stacktraces.pop(0)}
+ )
- message = Message("🤗 Results of the scheduled tests.", model_results, additional_results)
+ message = Message(title, ci_title, model_results, additional_results)
- message.post()
- message.post_reply()
+ # send report only if there is any failure (for push CI)
+ if message.n_failures or ci_event != "push":
+ message.post()
+ message.post_reply()
diff --git a/utils/notification_service_deprecated.py b/utils/notification_service_deprecated.py
deleted file mode 100644
index b14bff175192..000000000000
--- a/utils/notification_service_deprecated.py
+++ /dev/null
@@ -1,217 +0,0 @@
-# Copyright 2020 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Old script for Slack's notification service. Still here as the entire suite has not been moved to the newer implem.
-
-import os
-import re
-import sys
-
-from slack_sdk import WebClient
-
-
-def handle_test_results(test_results):
- expressions = test_results.split(" ")
-
- failed = 0
- success = 0
-
- # When the output is short enough, the output is surrounded by = signs: "== OUTPUT =="
- # When it is too long, those signs are not present.
- time_spent = expressions[-2] if "=" in expressions[-1] else expressions[-1]
-
- for i, expression in enumerate(expressions):
- if "failed" in expression:
- failed += int(expressions[i - 1])
- if "passed" in expression:
- success += int(expressions[i - 1])
-
- return failed, success, time_spent
-
-
-def format_for_slack(total_results, results, scheduled: bool, title: str):
- print(total_results, results)
- header = {
- "type": "header",
- "text": {
- "type": "plain_text",
- "text": title,
- "emoji": True,
- },
- }
-
- if total_results["failed"] > 0:
- total = {
- "type": "section",
- "fields": [
- {"type": "mrkdwn", "text": f"*Failures:*\n❌ {total_results['failed']} failures."},
- {"type": "mrkdwn", "text": f"*Passed:*\n✅ {total_results['success']} tests passed."},
- ],
- }
- else:
- total = {
- "type": "section",
- "fields": [
- {"type": "mrkdwn", "text": "\n🌞 All tests passed."},
- ],
- }
-
- blocks = [header, total]
-
- if total_results["failed"] > 0:
- for key, result in results.items():
- print(key, result)
- blocks.append({"type": "header", "text": {"type": "plain_text", "text": key, "emoji": True}})
- blocks.append(
- {
- "type": "section",
- "fields": [
- {
- "type": "mrkdwn",
- "text": f"*Results:*\n{result['failed']} failed, {result['success']} passed.",
- },
- {"type": "mrkdwn", "text": f"*Time spent:*\n{result['time_spent']}"},
- ],
- }
- )
- elif not scheduled:
- for key, result in results.items():
- blocks.append(
- {"type": "section", "fields": [{"type": "mrkdwn", "text": f"*{key}*\n{result['time_spent']}."}]}
- )
-
- footer = {
- "type": "section",
- "text": {
- "type": "mrkdwn",
- "text": f"",
- },
- }
-
- blocks.append(footer)
-
- blocks = {"blocks": blocks}
-
- return blocks
-
-
-if __name__ == "__main__":
- arguments = sys.argv[1:]
-
- if "scheduled" in arguments:
- arguments.remove("scheduled")
- scheduled = True
- else:
- scheduled = False
-
- if scheduled:
- # The scheduled run has several artifacts for each job.
- file_paths = {
- "TF Single GPU": {
- "common": "run_all_tests_tf_gpu_test_reports/[].txt",
- "pipeline": "run_all_tests_tf_gpu_test_reports/[].txt",
- },
- "Torch Single GPU": {
- "common": "run_all_tests_torch_gpu_test_reports/[].txt",
- "pipeline": "run_all_tests_torch_gpu_test_reports/[].txt",
- "examples": "run_all_tests_torch_gpu_test_reports/[].txt",
- },
- "TF Multi GPU": {
- "common": "run_all_tests_tf_multi_gpu_test_reports/[].txt",
- "pipeline": "run_all_tests_tf_multi_gpu_test_reports/[].txt",
- },
- "Torch Multi GPU": {
- "common": "run_all_tests_torch_multi_gpu_test_reports/[].txt",
- "pipeline": "run_all_tests_torch_multi_gpu_test_reports/[].txt",
- },
- "Torch Cuda Extensions Single GPU": {"common": "run_tests_torch_cuda_extensions_gpu_test_reports/[].txt"},
- "Torch Cuda Extensions Multi GPU": {
- "common": "run_tests_torch_cuda_extensions_multi_gpu_test_reports/[].txt"
- },
- }
- else:
- file_paths = {
- "TF Single GPU": {"common": "run_all_tests_tf_gpu_test_reports/[].txt"},
- "Torch Single GPU": {"common": "run_all_tests_torch_gpu_test_reports/[].txt"},
- "TF Multi GPU": {"common": "run_all_tests_tf_multi_gpu_test_reports/[].txt"},
- "Torch Multi GPU": {"common": "run_all_tests_torch_multi_gpu_test_reports/[].txt"},
- "Torch Cuda Extensions Single GPU": {"common": "run_tests_torch_cuda_extensions_gpu_test_reports/[].txt"},
- "Torch Cuda Extensions Multi GPU": {
- "common": "run_tests_torch_cuda_extensions_multi_gpu_test_reports/[].txt"
- },
- }
-
- client = WebClient(token=os.environ["CI_SLACK_BOT_TOKEN"])
-
- if not scheduled:
- channel_id = os.environ["CI_SLACK_CHANNEL_ID"]
- elif scheduled and len(arguments):
- channel_id = os.environ["CI_SLACK_CHANNEL_ID_PAST_FUTURE"]
- else:
- channel_id = os.environ["CI_SLACK_CHANNEL_ID_DAILY"]
-
- if scheduled:
- title = "🤗 Results of the scheduled tests."
- else:
- title = "🤗 Self-push results"
-
- if len(arguments):
- title = f"{arguments} " + title
-
- try:
- results = {}
- for job, file_dict in file_paths.items():
-
- # Single return value for failed/success across steps of a same job
- results[job] = {"failed": 0, "success": 0, "time_spent": "", "failures": ""}
-
- for key, file_path in file_dict.items():
- try:
- with open(file_path.replace("[]", "stats")) as f:
- failed, success, time_spent = handle_test_results(f.read())
- results[job]["failed"] += failed
- results[job]["success"] += success
- results[job]["time_spent"] += time_spent[1:-1] + ", "
- with open(file_path.replace("[]", "summary_short")) as f:
- for line in f:
- if re.search("FAILED", line):
- results[job]["failures"] += line
- except FileNotFoundError:
- print("Artifact was not found, job was probably canceled.")
-
- # Remove the trailing ", "
- results[job]["time_spent"] = results[job]["time_spent"][:-2]
-
- test_results_keys = ["failed", "success"]
- total = {"failed": 0, "success": 0}
- for job, job_result in results.items():
- for result_key in test_results_keys:
- total[result_key] += job_result[result_key]
-
- if total["failed"] != 0 or scheduled:
- to_be_sent_to_slack = format_for_slack(total, results, scheduled, title)
-
- result = client.chat_postMessage(
- channel=channel_id,
- blocks=to_be_sent_to_slack["blocks"],
- )
-
- for job, job_result in results.items():
- if len(job_result["failures"]):
- client.chat_postMessage(
- channel=channel_id, text=f"{job}\n{job_result['failures']}", thread_ts=result["ts"]
- )
-
- except Exception as e:
- # Voluntarily catch every exception and send it to Slack.
- raise Exception(f"Setup error: no artifacts were found. Error: {e}") from e
diff --git a/utils/notification_service_doc_tests.py b/utils/notification_service_doc_tests.py
index 58ceb567adbd..d02b08b605e1 100644
--- a/utils/notification_service_doc_tests.py
+++ b/utils/notification_service_doc_tests.py
@@ -118,7 +118,10 @@ def failures(self) -> Dict:
"type": "section",
"text": {
"type": "plain_text",
- "text": f"There were {self.n_failures} failures, out of {self.n_tests} tests.\nThe suite ran in {self.time}.",
+ "text": (
+ f"There were {self.n_failures} failures, out of {self.n_tests} tests.\nThe suite ran in"
+ f" {self.time}."
+ ),
"emoji": True,
},
"accessory": {
@@ -286,7 +289,7 @@ def retrieve_artifact(name: str):
files = os.listdir(name)
for file in files:
try:
- with open(os.path.join(name, file)) as f:
+ with open(os.path.join(name, file), encoding="utf-8") as f:
_artifact[file.split(".")[0]] = f.read()
except UnicodeDecodeError as e:
raise ValueError(f"Could not open {os.path.join(name, file)}.") from e
diff --git a/utils/past_ci_versions.py b/utils/past_ci_versions.py
new file mode 100644
index 000000000000..854127b34173
--- /dev/null
+++ b/utils/past_ci_versions.py
@@ -0,0 +1,141 @@
+import argparse
+import os
+
+
+past_versions_testing = {
+ "pytorch": {
+ "1.11": {
+ "torch": "1.11.0",
+ "torchvision": "0.12.0",
+ "torchaudio": "0.11.0",
+ "python": 3.9,
+ "cuda": "cu113",
+ "install": (
+ "python3 -m pip install --no-cache-dir -U torch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0"
+ " --extra-index-url https://download.pytorch.org/whl/cu113"
+ ),
+ },
+ "1.10": {
+ "torch": "1.10.2",
+ "torchvision": "0.11.3",
+ "torchaudio": "0.10.2",
+ "python": 3.9,
+ "cuda": "cu113",
+ "install": (
+ "python3 -m pip install --no-cache-dir -U torch==1.10.2 torchvision==0.11.3 torchaudio==0.10.2"
+ " --extra-index-url https://download.pytorch.org/whl/cu113"
+ ),
+ },
+ # torchaudio < 0.10 has no CUDA-enabled binary distributions
+ "1.9": {
+ "torch": "1.9.1",
+ "torchvision": "0.10.1",
+ "torchaudio": "0.9.1",
+ "python": 3.9,
+ "cuda": "cu111",
+ "install": (
+ "python3 -m pip install --no-cache-dir -U torch==1.9.1 torchvision==0.10.1 torchaudio==0.9.1"
+ " --extra-index-url https://download.pytorch.org/whl/cu111"
+ ),
+ },
+ "1.8": {
+ "torch": "1.8.1",
+ "torchvision": "0.9.1",
+ "torchaudio": "0.8.1",
+ "python": 3.9,
+ "cuda": "cu111",
+ "install": (
+ "python3 -m pip install --no-cache-dir -U torch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1"
+ " --extra-index-url https://download.pytorch.org/whl/cu111"
+ ),
+ },
+ "1.7": {
+ "torch": "1.7.1",
+ "torchvision": "0.8.2",
+ "torchaudio": "0.7.2",
+ "python": 3.9,
+ "cuda": "cu110",
+ "install": (
+ "python3 -m pip install --no-cache-dir -U torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2"
+ " --extra-index-url https://download.pytorch.org/whl/cu110"
+ ),
+ },
+ "1.6": {
+ "torch": "1.6.0",
+ "torchvision": "0.7.0",
+ "torchaudio": "0.6.0",
+ "python": 3.8,
+ "cuda": "cu101",
+ "install": (
+ "python3 -m pip install --no-cache-dir -U torch==1.6.0 torchvision==0.7.0 torchaudio==0.6.0"
+ " --extra-index-url https://download.pytorch.org/whl/cu101"
+ ),
+ },
+ "1.5": {
+ "torch": "1.5.1",
+ "torchvision": "0.6.1",
+ "torchaudio": "0.5.1",
+ "python": 3.8,
+ "cuda": "cu101",
+ "install": (
+ "python3 -m pip install --no-cache-dir -U torch==1.5.1 torchvision==0.6.1 torchaudio==0.5.1"
+ " --extra-index-url https://download.pytorch.org/whl/cu101"
+ ),
+ },
+ "1.4": {
+ "torch": "1.4.0",
+ "torchvision": "0.5.0",
+ "torchaudio": "0.4.0",
+ "python": 3.8,
+ "cuda": "cu100",
+ "install": (
+ "python3 -m pip install --no-cache-dir -U torch==1.4.0 torchvision==0.5.0 torchaudio==0.4.0"
+ " --extra-index-url https://download.pytorch.org/whl/cu100"
+ ),
+ },
+ },
+ "tensorflow": {
+ "2.8": {
+ "tensorflow": "2.8.2",
+ "install": "python3 -m pip install --no-cache-dir -U tensorflow==2.8.2",
+ },
+ "2.7": {
+ "tensorflow": "2.7.3",
+ "install": "python3 -m pip install --no-cache-dir -U tensorflow==2.7.3",
+ },
+ "2.6": {
+ "tensorflow": "2.6.5",
+ "install": "python3 -m pip install --no-cache-dir -U tensorflow==2.6.5",
+ },
+ "2.5": {
+ "tensorflow": "2.5.3",
+ "install": "python3 -m pip install --no-cache-dir -U tensorflow==2.5.3",
+ },
+ # needs a different `nvidia/cuda` docker image, otherwise the GPU does not work
+ "2.4": {
+ "tensorflow": "2.4.4",
+ "install": "python3 -m pip install --no-cache-dir -U tensorflow==2.4.4",
+ # This should be specified as a docker build argument.
+ # We keep the information here for reference only.
+ "base_docker": "nvidia/cuda:11.0.3-cudnn8-devel-ubuntu20.04",
+ },
+ },
+}
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser("Choose the framework and version to install")
+ parser.add_argument("--framework", help="The framework to install. Should be `pytorch` or `tensorflow`", type=str)
+ parser.add_argument("--version", help="The version of the framework to install.", type=str)
+ args = parser.parse_args()
+
+ info = past_versions_testing[args.framework][args.version]
+
+ os.system(f'echo "export INSTALL_CMD=\'{info["install"]}\'" >> ~/.profile')
+ print(f'echo "export INSTALL_CMD=\'{info["install"]}\'" >> ~/.profile')
+
+ cuda = ""
+ if args.framework == "pytorch":
+ cuda = info["cuda"]
+ os.system(f"echo \"export CUDA='{cuda}'\" >> ~/.profile")
+ print(f"echo \"export CUDA='{cuda}'\" >> ~/.profile")
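A brief sketch of how this lookup table is meant to be consumed (the command line in the comment mirrors the argparse interface above; the surrounding workflow wiring is not part of this file, and the snippet assumes `past_versions_testing` is in scope):

```python
# Hypothetical invocation from the repository root:
#   python3 utils/past_ci_versions.py --framework pytorch --version 1.11
# which appends INSTALL_CMD / CUDA exports to ~/.profile for later build steps to source.
info = past_versions_testing["pytorch"]["1.11"]
assert info["cuda"] == "cu113"
assert info["install"].startswith("python3 -m pip install --no-cache-dir -U torch==1.11.0")
```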
diff --git a/utils/prepare_for_doc_test.py b/utils/prepare_for_doc_test.py
index 123219954cd0..c55f3540d994 100644
--- a/utils/prepare_for_doc_test.py
+++ b/utils/prepare_for_doc_test.py
@@ -28,7 +28,7 @@
When debugging the doc tests locally, please make sure to
always run:
- ```python utils/prepare_for_doc_test.py src doc```
+ ```python utils/prepare_for_doc_test.py src docs```
before running the doc tests:
@@ -36,7 +36,7 @@
Afterwards you should revert the changes by running
- ```python utils/prepare_for_doc_test.py src doc --remove_new_line```
+ ```python utils/prepare_for_doc_test.py src docs --remove_new_line```
"""
import argparse
@@ -92,6 +92,9 @@ def process_doc_file(code_file, add_new_line=True):
# fmt: off
splits = code.split("```")
+ if len(splits) % 2 != 1:
+ raise ValueError("The number of occurrences of ``` should be an even number.")
+
splits = [s if i % 2 == 0 else process_code_block(s, add_new_line=add_new_line) for i, s in enumerate(splits)]
clean_code = "```".join(splits)
# fmt: on
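To see why the new guard works (a made-up snippet, not taken from the repository): splitting on the triple-backtick marker produces an odd number of parts only when every opened code block is closed.

```python
fence = "`" * 3  # the triple-backtick marker used for code blocks
balanced = f"text\n{fence}py\ncode\n{fence}\nmore text"
unbalanced = f"text\n{fence}py\ncode"  # missing closing fence

assert len(balanced.split(fence)) % 2 == 1    # 2 markers -> 3 parts, accepted
assert len(unbalanced.split(fence)) % 2 == 0  # 1 marker -> 2 parts, rejected
```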
diff --git a/utils/print_env.py b/utils/print_env.py
new file mode 100644
index 000000000000..443ed6eab6c4
--- /dev/null
+++ b/utils/print_env.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# this script dumps information about the environment
+
+import os
+import sys
+
+import transformers
+
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+print("Python version:", sys.version)
+print("transformers version:", transformers.__version__)
+
+try:
+ import torch
+
+ print("Torch version:", torch.__version__)
+ print("Cuda available:", torch.cuda.is_available())
+ print("Cuda version:", torch.version.cuda)
+ print("CuDNN version:", torch.backends.cudnn.version())
+ print("Number of GPUs available:", torch.cuda.device_count())
+ print("NCCL version:", torch.cuda.nccl.version())
+except ImportError:
+ print("Torch version:", None)
+
+try:
+ import deepspeed
+
+ print("DeepSpeed version:", deepspeed.__version__)
+except ImportError:
+ print("DeepSpeed version:", None)
+
+try:
+ import tensorflow as tf
+
+ print("TensorFlow version:", tf.__version__)
+ print("TF GPUs available:", bool(tf.config.list_physical_devices("GPU")))
+ print("Number of TF GPUs available:", len(tf.config.list_physical_devices("GPU")))
+except ImportError:
+ print("TensorFlow version:", None)
diff --git a/utils/print_env_pt.py b/utils/print_env_pt.py
deleted file mode 100755
index 94451541f646..000000000000
--- a/utils/print_env_pt.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#!/usr/bin/env python3
-
-# coding=utf-8
-# Copyright 2020 The HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# this script dumps information about the environment
-
-import torch
-
-
-print("Torch version:", torch.__version__)
-print("Cuda available:", torch.cuda.is_available())
-print("Cuda version:", torch.version.cuda)
-print("CuDNN version:", torch.backends.cudnn.version())
-print("Number of GPUs available:", torch.cuda.device_count())
-print("NCCL version:", torch.cuda.nccl.version())
diff --git a/utils/release.py b/utils/release.py
index 5a9c15f6ae06..3bb75f0bebf4 100644
--- a/utils/release.py
+++ b/utils/release.py
@@ -123,7 +123,7 @@ def pre_release_work(patch=False):
print(f"Updating version to {version}.")
global_version_update(version, patch=patch)
if not patch:
- print("Cleaning main README")
+ print("Cleaning main README, don't forget to run `make fix-copies`.")
clean_main_ref_in_model_list()
@@ -141,6 +141,8 @@ def post_release_work():
print(f"Updating version to {version}.")
global_version_update(version)
+ print("Cleaning main README, don't forget to run `make fix-copies`.")
+ clean_main_ref_in_model_list()
if __name__ == "__main__":
diff --git a/utils/sort_auto_mappings.py b/utils/sort_auto_mappings.py
new file mode 100644
index 000000000000..ef985dc43cd4
--- /dev/null
+++ b/utils/sort_auto_mappings.py
@@ -0,0 +1,89 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import re
+
+
+PATH_TO_AUTO_MODULE = "src/transformers/models/auto"
+
+
+# re pattern that matches mapping introductions:
+# SUPER_MODEL_MAPPING_NAMES = OrderedDict or SUPER_MODEL_MAPPING = OrderedDict
+_re_intro_mapping = re.compile(r"[A-Z_]+_MAPPING(\s+|_[A-Z_]+\s+)=\s+OrderedDict")
+# re pattern that matches identifiers in mappings
+_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')
+
+
+def sort_auto_mapping(fname, overwrite: bool = False):
+ with open(fname, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ lines = content.split("\n")
+ new_lines = []
+ line_idx = 0
+ while line_idx < len(lines):
+ if _re_intro_mapping.search(lines[line_idx]) is not None:
+ indent = len(re.search(r"^(\s*)\S", lines[line_idx]).groups()[0]) + 8
+ # Start of a new mapping!
+ while not lines[line_idx].startswith(" " * indent + "("):
+ new_lines.append(lines[line_idx])
+ line_idx += 1
+
+ blocks = []
+ while lines[line_idx].strip() != "]":
+ # Blocks either fit in one line or not
+ if lines[line_idx].strip() == "(":
+ start_idx = line_idx
+ while not lines[line_idx].startswith(" " * indent + ")"):
+ line_idx += 1
+ blocks.append("\n".join(lines[start_idx : line_idx + 1]))
+ else:
+ blocks.append(lines[line_idx])
+ line_idx += 1
+
+ # Sort blocks by their identifiers
+ blocks = sorted(blocks, key=lambda x: _re_identifier.search(x).groups()[0])
+ new_lines += blocks
+ else:
+ new_lines.append(lines[line_idx])
+ line_idx += 1
+
+ if overwrite:
+ with open(fname, "w", encoding="utf-8") as f:
+ f.write("\n".join(new_lines))
+ elif "\n".join(new_lines) != content:
+ return True
+
+
+def sort_all_auto_mappings(overwrite: bool = False):
+ fnames = [os.path.join(PATH_TO_AUTO_MODULE, f) for f in os.listdir(PATH_TO_AUTO_MODULE) if f.endswith(".py")]
+ diffs = [sort_auto_mapping(fname, overwrite=overwrite) for fname in fnames]
+
+ if not overwrite and any(diffs):
+ failures = [f for f, d in zip(fnames, diffs) if d]
+ raise ValueError(
+ f"The following files have auto mappings that need sorting: {', '.join(failures)}. Run `make style` to fix"
+ " this."
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
+ args = parser.parse_args()
+
+ sort_all_auto_mappings(not args.check_only)
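As a small, hypothetical illustration of the sorting key used above, `_re_identifier` pulls the quoted model identifier out of each mapping entry, so single-line and multi-line blocks end up ordered alphabetically by model name:

```python
import re

_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')

# Two hypothetical mapping entries, deliberately out of order.
blocks = [
    '        ("gpt2", "GPT2Config"),',
    '        ("albert", "AlbertConfig"),',
]
blocks = sorted(blocks, key=lambda x: _re_identifier.search(x).groups()[0])
assert [b.strip() for b in blocks] == ['("albert", "AlbertConfig"),', '("gpt2", "GPT2Config"),']
```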
diff --git a/utils/test_module/custom_pipeline.py b/utils/test_module/custom_pipeline.py
new file mode 100644
index 000000000000..4c7928b1ccd1
--- /dev/null
+++ b/utils/test_module/custom_pipeline.py
@@ -0,0 +1,33 @@
+import numpy as np
+
+from transformers import Pipeline
+
+
+def softmax(outputs):
+ maxes = np.max(outputs, axis=-1, keepdims=True)
+ shifted_exp = np.exp(outputs - maxes)
+ return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
+
+
+class PairClassificationPipeline(Pipeline):
+ def _sanitize_parameters(self, **kwargs):
+ preprocess_kwargs = {}
+ if "second_text" in kwargs:
+ preprocess_kwargs["second_text"] = kwargs["second_text"]
+ return preprocess_kwargs, {}, {}
+
+ def preprocess(self, text, second_text=None):
+ return self.tokenizer(text, text_pair=second_text, return_tensors=self.framework)
+
+ def _forward(self, model_inputs):
+ return self.model(**model_inputs)
+
+ def postprocess(self, model_outputs):
+ logits = model_outputs.logits[0].numpy()
+ probabilities = softmax(logits)
+
+ best_class = np.argmax(probabilities)
+ label = self.model.config.id2label[best_class]
+ score = probabilities[best_class].item()
+ logits = logits.tolist()
+ return {"label": label, "score": score, "logits": logits}
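A short aside on the `softmax` helper above: subtracting the per-row maximum before exponentiating is the standard numerical-stability trick, so even very large logits stay finite (the values below are purely illustrative).

```python
import numpy as np

def softmax(outputs):
    # Same helper as above: shift by the per-row max before exponentiating.
    maxes = np.max(outputs, axis=-1, keepdims=True)
    shifted_exp = np.exp(outputs - maxes)
    return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)

# A naive np.exp([1000., 1001.]) overflows to inf; the shifted version does not.
probs = softmax(np.array([[1000.0, 1001.0]]))
assert np.isfinite(probs).all()
assert np.allclose(probs.sum(axis=-1), 1.0)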
diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py
index f733f301e8dc..329d248de3c0 100644
--- a/utils/tests_fetcher.py
+++ b/utils/tests_fetcher.py
@@ -15,6 +15,7 @@
import argparse
import collections
+import json
import os
import re
from contextlib import contextmanager
@@ -65,6 +66,32 @@ def clean_code(content):
return "\n".join(lines_to_keep)
+def get_all_tests():
+ """
+ Return a list of paths to all test folders and files under `tests`. All paths are rooted at `tests`.
+
+ - folders under `tests`: `tokenization`, `pipelines`, etc. The folder `models` is excluded.
+ - folders under `tests/models`: `bert`, `gpt2`, etc.
+ - test files under `tests`: `test_modeling_common.py`, `test_tokenization_common.py`, etc.
+ """
+ test_root_dir = os.path.join(PATH_TO_TRANFORMERS, "tests")
+
+ # test folders/files directly under `tests` folder
+ tests = os.listdir(test_root_dir)
+ tests = sorted(
+ list(filter(lambda x: os.path.isdir(x) or x.startswith("tests/test_"), [f"tests/{x}" for x in tests]))
+ )
+
+ # model specific test folders
+ model_tests_folders = os.listdir(os.path.join(test_root_dir, "models"))
+ model_test_folders = sorted(list(filter(os.path.isdir, [f"tests/models/{x}" for x in model_tests_folders])))
+
+ tests.remove("tests/models")
+ tests = model_test_folders + tests
+
+ return tests
+
+
def diff_is_docstring_only(repo, branching_point, filename):
"""
Check if the diff is only in docstrings in a filename.
@@ -199,20 +226,78 @@ def get_test_dependencies(test_fname):
relative_imports = re.findall(r"from\s+(\.\S+)\s+import\s+([^\n]+)\n", content)
relative_imports = [test for test, imp in relative_imports if "# tests_ignore" not in imp]
- # Removes the double trailing '..' for parent imports, and creates an absolute path from the root dir with
- # `tests` as a prefix.
- parent_imports = [imp.strip(".") for imp in relative_imports if ".." in imp]
- parent_imports = [os.path.join("tests", f"{test.replace('.', os.path.sep)}.py") for test in parent_imports]
-
- # Removes the single trailing '.' for current dir imports, and creates an absolute path from the root dir with
- # tests/{module_name} as a prefix.
- current_dir_imports = [imp.strip(".") for imp in relative_imports if ".." not in imp]
- directory = os.path.sep.join(test_fname.split(os.path.sep)[:-1])
- current_dir_imports = [
- os.path.join(directory, f"{test.replace('.', os.path.sep)}.py") for test in current_dir_imports
+ def _convert_relative_import_to_file(relative_import):
+ level = 0
+ while relative_import.startswith("."):
+ level += 1
+ relative_import = relative_import[1:]
+
+ directory = os.path.sep.join(test_fname.split(os.path.sep)[:-level])
+ return os.path.join(directory, f"{relative_import.replace('.', os.path.sep)}.py")
+
+ dependencies = [_convert_relative_import_to_file(relative_import) for relative_import in relative_imports]
+ return [f for f in dependencies if os.path.isfile(os.path.join(PATH_TO_TRANFORMERS, f))]
+
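Assuming relative test imports typically look like `from ...test_modeling_common import ModelTesterMixin` inside a model test file, the inner helper above resolves them by stripping one path component per leading dot. A hypothetical trace (the standalone function below mirrors the closure, with the test file name passed explicitly):

```python
import os

def _convert_relative_import_to_file(test_fname, relative_import):
    # Same logic as the inner helper above, made standalone for illustration.
    level = 0
    while relative_import.startswith("."):
        level += 1
        relative_import = relative_import[1:]
    directory = os.path.sep.join(test_fname.split(os.path.sep)[:-level])
    return os.path.join(directory, f"{relative_import.replace('.', os.path.sep)}.py")

# On a POSIX layout, three leading dots walk up from tests/models/bert to tests/.
result = _convert_relative_import_to_file(
    "tests/models/bert/test_modeling_bert.py", "...test_modeling_common"
)
assert result == os.path.join("tests", "test_modeling_common.py")
```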
+
+def create_reverse_dependency_tree():
+ """
+ Create a list of all edges (a, b), meaning that modifying a impacts b, where a ranges over all module and test files.
+ """
+ modules = [
+ str(f.relative_to(PATH_TO_TRANFORMERS))
+ for f in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
]
+ module_edges = [(d, m) for m in modules for d in get_module_dependencies(m)]
- return [f for f in [*parent_imports, *current_dir_imports] if os.path.isfile(f)]
+ tests = [str(f.relative_to(PATH_TO_TRANFORMERS)) for f in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/*.py")]
+ test_edges = [(d, t) for t in tests for d in get_test_dependencies(t)]
+
+ return module_edges + test_edges
+
+
+def get_tree_starting_at(module, edges):
+ """
+ Returns the tree starting at a given module, following all edges, in the following format: [module, [list of edges
+ starting at module], [list of edges starting at the vertices reached in the previous level], ...]
+ """
+ vertices_seen = [module]
+ new_edges = [edge for edge in edges if edge[0] == module and edge[1] != module]
+ tree = [module]
+ while len(new_edges) > 0:
+ tree.append(new_edges)
+ final_vertices = list(set(edge[1] for edge in new_edges))
+ vertices_seen.extend(final_vertices)
+ new_edges = [edge for edge in edges if edge[0] in final_vertices and edge[1] not in vertices_seen]
+
+ return tree
+
+
+def print_tree_deps_of(module, all_edges=None):
+ """
+ Prints the tree of modules depending on a given module.
+ """
+ if all_edges is None:
+ all_edges = create_reverse_dependency_tree()
+ tree = get_tree_starting_at(module, all_edges)
+
+ # The list of lines is a list of tuples (line_to_be_printed, module)
+ # Keeping the modules lets us know where to insert each new line in the list.
+ lines = [(tree[0], tree[0])]
+ for index in range(1, len(tree)):
+ edges = tree[index]
+ start_edges = set([edge[0] for edge in edges])
+
+ for start in start_edges:
+ end_edges = set([edge[1] for edge in edges if edge[0] == start])
+ # We will insert all those edges just after the line showing start.
+ pos = 0
+ while lines[pos][1] != start:
+ pos += 1
+ lines = lines[: pos + 1] + [(" " * (2 * index) + end, end) for end in end_edges] + lines[pos + 1 :]
+
+ for line in lines:
+ # We don't print the refs that were just here to help build the lines.
+ print(line[0])
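A tiny, hypothetical edge list makes the tree format returned by `get_tree_starting_at` concrete (assuming the function above is in scope): the first element is the module itself, followed by one list of edges per dependency level.

```python
# Hypothetical edges: touching a.py impacts b.py, and touching b.py impacts c.py.
edges = [("a.py", "b.py"), ("b.py", "c.py")]
tree = get_tree_starting_at("a.py", edges)
assert tree == ["a.py", [("a.py", "b.py")], [("b.py", "c.py")]]
```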
def create_reverse_dependency_map():
@@ -268,7 +353,7 @@ def create_reverse_dependency_map():
"feature_extraction_sequence_utils.py": "test_sequence_feature_extraction_common.py",
"feature_extraction_utils.py": "test_feature_extraction_common.py",
"file_utils.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
- "utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
+ "utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py", "utils/test_generic.py"],
"utils/hub.py": "utils/test_file_utils.py",
"modelcard.py": "utils/test_model_card.py",
"modeling_flax_utils.py": "test_modeling_flax_common.py",
@@ -292,6 +377,7 @@ def create_reverse_dependency_map():
],
"optimization.py": "optimization/test_optimization.py",
"optimization_tf.py": "optimization/test_optimization_tf.py",
+ "pipelines/__init__.py": "pipelines/test_pipelines_*.py",
"pipelines/base.py": "pipelines/test_pipelines_*.py",
"pipelines/text2text_generation.py": [
"pipelines/test_pipelines_text2text_generation.py",
@@ -441,7 +527,7 @@ def sanity_check():
)
-def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
+def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None, json_output_file=None):
modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")
@@ -495,6 +581,42 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
with open(output_file, "w", encoding="utf-8") as f:
f.write(" ".join(test_files_to_run))
+ # Create a map that maps test categories to test files, e.g. `models/bert` -> [...test_modeling_bert.py, ...]
+
+ # Get all test directories (and some common test files) under `tests` and `tests/models` if `test_files_to_run`
+ # contains `tests` (i.e. when `setup.py` is changed).
+ if "tests" in test_files_to_run:
+ test_files_to_run = get_all_tests()
+
+ if json_output_file is not None:
+ test_map = {}
+ for test_file in test_files_to_run:
+ # `test_file` is a path to a test folder/file, starting with `tests/`. For example,
+ # - `tests/models/bert/test_modeling_bert.py` or `tests/models/bert`
+ # - `tests/trainer/test_trainer.py` or `tests/trainer`
+ # - `tests/test_modeling_common.py`
+ names = test_file.split(os.path.sep)
+ if names[1] == "models":
+ # take the part like `models/bert` for modeling tests
+ key = "/".join(names[1:3])
+ elif len(names) > 2 or not test_file.endswith(".py"):
+ # test folders under `tests` or python files under them
+ # take the part like `tokenization`, `pipelines`, etc. for other test categories
+ key = "/".join(names[1:2])
+ else:
+ # common test files directly under `tests/`
+ key = "common"
+
+ if key not in test_map:
+ test_map[key] = []
+ test_map[key].append(test_file)
+
+ # sort the keys & values
+ keys = sorted(test_map.keys())
+ test_map = {k: " ".join(sorted(test_map[k])) for k in keys}
+ with open(json_output_file, "w", encoding="UTF-8") as fp:
+ json.dump(test_map, fp, ensure_ascii=False)
+
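Concretely, the key derivation above groups test paths into the three branches shown below. The `_category_key` helper is hypothetical, written only to mirror the inline branching for a single path; the example paths are illustrative.

```python
import os

def _category_key(test_file):
    # Mirrors the branching above for a single test path.
    names = test_file.split(os.path.sep)
    if names[1] == "models":
        return "/".join(names[1:3])
    elif len(names) > 2 or not test_file.endswith(".py"):
        return "/".join(names[1:2])
    return "common"

assert _category_key(os.path.join("tests", "models", "bert", "test_modeling_bert.py")) == "models/bert"
assert _category_key(os.path.join("tests", "trainer", "test_trainer.py")) == "trainer"
assert _category_key(os.path.join("tests", "test_modeling_common.py")) == "common"
```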
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -504,6 +626,12 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
parser.add_argument(
"--output_file", type=str, default="test_list.txt", help="Where to store the list of tests to run"
)
+ parser.add_argument(
+ "--json_output_file",
+ type=str,
+ default="test_map.json",
+ help="Where to store the tests to run in a dictionary format mapping test categories to test files",
+ )
parser.add_argument(
"--diff_with_last_commit",
action="store_true",
@@ -516,8 +644,16 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
default=["tests"],
help="Only keep the test files matching one of those filters.",
)
+ parser.add_argument(
+ "--print_dependencies_of",
+ type=str,
+ help="Will only print the tree of modules depending on the file passed.",
+ default=None,
+ )
args = parser.parse_args()
- if args.sanity_check:
+ if args.print_dependencies_of is not None:
+ print_tree_deps_of(args.print_dependencies_of)
+ elif args.sanity_check:
sanity_check()
else:
repo = Repo(PATH_TO_TRANFORMERS)
@@ -528,7 +664,12 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
diff_with_last_commit = True
try:
- infer_tests_to_run(args.output_file, diff_with_last_commit=diff_with_last_commit, filters=args.filters)
+ infer_tests_to_run(
+ args.output_file,
+ diff_with_last_commit=diff_with_last_commit,
+ filters=args.filters,
+ json_output_file=args.json_output_file,
+ )
except Exception as e:
print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
with open(args.output_file, "w", encoding="utf-8") as f:
diff --git a/valohai.yaml b/valohai.yaml
deleted file mode 100644
index 14441e27d02d..000000000000
--- a/valohai.yaml
+++ /dev/null
@@ -1,91 +0,0 @@
----
-
-- step:
- name: Execute python examples/text-classification/run_glue.py
- image: pytorch/pytorch:nightly-devel-cuda10.0-cudnn7
- command:
- - python /valohai/repository/utils/download_glue_data.py --data_dir=/glue_data
- - pip install -e .
- - pip install -r examples/requirements.txt
- - python examples/text-classification/run_glue.py --do_train --data_dir=/glue_data/{parameter-value:task_name} {parameters}
- parameters:
- - name: model_type
- pass-as: --model_type={v}
- type: string
- default: bert
- - name: model_name_or_path
- pass-as: --model_name_or_path={v}
- type: string
- default: bert-base-uncased
- - name: task_name
- pass-as: --task_name={v}
- type: string
- default: MRPC
- - name: max_seq_length
- pass-as: --max_seq_length={v}
- description: The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded.
- type: integer
- default: 128
- - name: per_gpu_train_batch_size
- pass-as: --per_gpu_train_batch_size={v}
- description: Batch size per GPU/CPU for training.
- type: integer
- default: 8
- - name: per_gpu_eval_batch_size
- pass-as: --per_gpu_eval_batch_size={v}
- description: Batch size per GPU/CPU for evaluation.
- type: integer
- default: 8
- - name: gradient_accumulation_steps
- pass-as: --gradient_accumulation_steps={v}
- description: Number of updates steps to accumulate before performing a backward/update pass.
- type: integer
- default: 1
- - name: learning_rate
- pass-as: --learning_rate={v}
- description: The initial learning rate for Adam.
- type: float
- default: 0.00005
- - name: adam_epsilon
- pass-as: --adam_epsilon={v}
- description: Epsilon for Adam optimizer.
- type: float
- default: 0.00000001
- - name: max_grad_norm
- pass-as: --max_grad_norm={v}
- description: Max gradient norm.
- type: float
- default: 1.0
- - name: num_train_epochs
- pass-as: --num_train_epochs={v}
- description: Total number of training epochs to perform.
- type: integer
- default: 3
- - name: max_steps
- pass-as: --max_steps={v}
- description: If > 0, set total number of training steps to perform. Override num_train_epochs.
- type: integer
- default: -1
- - name: warmup_steps
- pass-as: --warmup_steps={v}
- description: Linear warmup over warmup_steps.
- type: integer
- default: -1
- - name: logging_steps
- pass-as: --logging_steps={v}
- description: Log every X updates steps.
- type: integer
- default: 25
- - name: save_steps
- pass-as: --save_steps={v}
- description: Save checkpoint every X updates steps.
- type: integer
- default: -1
- - name: output_dir
- pass-as: --output_dir={v}
- type: string
- default: /valohai/outputs
- - name: evaluation_strategy
- description: The evaluation strategy to use.
- type: string
- default: steps